Как работает tensor parallelism с FP8 в vLLM?

Краткий тезис

parallelism|Tensor parallelism с FP8 в vLLM — это техника распределённого инференса, при которой веса модели и активации хранятся в 8-битном формате с плавающей точкой (FP8), а коммуникация между GPU (например, AllReduce) выполняется в FP16. Это позволяет ускорить вычисления за счёт снижения объёма передаваемых данных и более быстрых матричных операций на H100 (поддержка FP8). На практике vLLM автоматически конвертирует тензоры в FP8 перед вычислениями и обратно в FP16 для коммуникации, что даёт прирост производительности до 1.8x при batch=128.


1. Термин: Tensor Parallelism (TP)

parallelism|Tensor parallelism — это способ распределения модели по нескольким GPU, при котором каждый слой разбивается на части (например, веса матрицы делятся на несколько частей) и каждая часть обрабатывается на отдельном GPU. Это позволяет ускорить инференс за счёт параллельной обработки, но требует синхронизации результатов (AllReduce).

Зачем нужен TP

  • Уменьшение времени инференса за счёт параллельной обработки.
  • Возможность работы с моделями, которые не помещаются в память одной GPU.

Как работает TP в vLLM

  • Модель разбивается на "шарды" (shards) по числу GPU.
  • Каждый GPU обрабатывает свою часть данных.
  • Результаты объединяются через AllReduce.

2. Термин: FP8 (8-bit Floating Point)

FP8 — это 8-битный формат чисел с плавающей точкой, который обеспечивает более высокую скорость вычислений и меньшее потребление памяти по сравнению с FP16 или FP32. В контексте vLLM FP8 используется для хранения весов и активаций.

Особенности FP8

  • Меньшая точность (2-3 значащих цифры), но достаточная для инференса.
  • Поддержка на H100 (Hopper) и более новых GPU.
  • Ускорение за счёт специализированных тензорных ядер.

Проблемы FP8

  • Ограниченный динамический диапазон (может привести к переполнению или потере точности).
  • Требует калибровки (scaling factors) для корректной работы.

3. Как vLLM использует FP8 для Tensor Parallelism

vLLM автоматически конвертирует веса модели в FP8 при загрузке (если включён режим FP8). При этом:

  • Веса хранятся в FP8.
  • Активации (промежуточные тензоры) также могут быть в FP8.
  • Коммуникация между GPU (AllReduce) выполняется в FP16 (или FP32) для сохранения точности.

Почему коммуникация в FP16, а не FP8:

  • AllReduce требует суммирования тензоров от всех GPU. При суммировании в FP8 возможна потеря точности из-за ограниченного диапазона.
  • FP16 обеспечивает достаточную точность для агрегации.

Процесс

  1. Каждый GPU вычисляет свою часть результата (логиты) в FP8.
  2. Перед AllReduce логиты преобразуются в FP16.
  3. AllReduce суммирует FP16-тензоры от всех GPU.
  4. Результат преобразуется обратно в FP8 для следующего слоя (или для вывода).

4. AllReduce в контексте TP с FP8

AllReduce — это операция, которая суммирует тензоры от всех GPU и распределяет результат обратно. В vLLM с FP8:

  • Вход AllReduce FP16-тензоры (преобразованные из FP8).
  • Выход AllReduce FP16-тензор, который затем преобразуется в FP8.

Overhead коммуникации

  • Размер передаваемых данных такой же, как при FP16 (так как AllReduce работает с FP16).
  • Время коммуникации не увеличивается, но вычисления ускоряются за счёт FP8.

Пример конфигурации vLLM для TP с FP8:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-2-7b-hf",
    tensor_parallel_size=4,  # 4 GPU
    dtype="fp8",             # Включение FP8
    max_num_batched_tokens=4096,
)

5. Ускорение: сравнение FP8 vs FP16

Бенчмарки (на H100, batch=128):

Причины ускорения

  • FP8 требует меньше памяти, что снижает bottleneck памяти.
  • Тензорные ядра H100 работают быстрее с FP8.
  • Меньший объём данных для передачи (но в vLLM коммуникация в FP16, поэтому выигрыш только в compute).

Ограничения

  • Ускорение зависит от batch size: при малых batch (1-16) выигрыш меньше.
  • Не все операции поддерживают FP8 (например, softmax остаётся в FP32).

6. Когда использовать FP8 с TP

Рекомендуется

  • Большие модели (70B+), которые требуют распределения по нескольким GPU.
  • Высокая нагрузка (batch > 64).
  • Задачи, где latency критична (чат-боты, real-time API).

Не рекомендуется

  • Малые модели (7B-13B), которые помещаются на одну GPU.
  • Высокая точность (например, научные расчёты).

7. Проблемы и ограничения

Проблемы

  • Калибровка Для FP8 нужны scaling factors, которые подбираются на калибровочном датасете. Если данные отличаются, возможна потеря точности.
  • Совместимость Не все GPU поддерживают FP8 (только H100 и новее).
  • Overhead конвертации Преобразование FP8 ↔ FP16 занимает время, но оно компенсируется ускорением вычислений.

Ограничения vLLM

  • FP8 поддерживается только для линейных слоёв (attention, FFN).
  • Остальные операции (embedding, layernorm) остаются в FP16/FP32.

8. Пример кода: запуск vLLM с FP8 и TP

from vllm import LLM, SamplingParams
import torch

# Параметры
model_name = "meta-llama/Llama-2-70b-hf"
tp_size = 4  # 4 GPU
batch_size = 128

# Инициализация с FP8
llm = LLM(
    model=model_name,
    tensor_parallel_size=tp_size,
    dtype="fp8",  # Включение FP8
    max_num_batched_tokens=4096,
    trust_remote_code=True,
)

# Генерация
prompts = ["Hello, world!"] * batch_size
sampling_params = SamplingParams(temperature=0.7, max_tokens=100)
outputs = llm.generate(prompts, sampling_params)

# Вывод
for output in outputs:
    print(output.outputs[0].text)

Ожидаемый результат Ускорение по сравнению с FP16.


9. Пет-проект для закрепления

Задача Сравнить производительность vLLM с FP8 и FP16 на модели LLaMA-2-70B с TP=4.

Инструменты

  • vLLM (установка: pip install vllm).
  • Бенчмарк-скрипт (time, throughput).

Шаги:

  1. Запустить инференс с FP16 (dtype="float16") и замерить latency/throughput.
  2. Запустить с FP8 (dtype="fp8") и замерить.
  3. Сравнить результаты (таблица).

Ожидаемый результат

Таблица сравнения

МетрикаFP16FP8Ускорение
Latency (ms)150831.8x
Throughput (req/s)1001801.8x
Perplexity5.25.210.01 diff

10. Связь с другими вопросами

ВопросТема
459Как работает pipeline parallelism в vLLM?
461Как работает sequence parallelism в vLLM?
462Как работает quantization в vLLM?
463Как работает speculative decoding в vLLM?
464Как работает prefix caching в vLLM?

11. Навигация


Навигация