Почему decode stage плохо batchится?

Краткий тезис

Decode stage (этап автокорреляции) при генерации каждого следующего токена требует значительного объёма оперативной памяти из-за постоянного чтения и записи кэша KV (KV cache) для каждого запроса. Современные LLM при декодировании становятся memory-bound: пропускная способность памяти (HBM) является узким местом, а не вычислительные ресурсы (FLOPS). batching|Static batching (пакетная обработка с фиксированным размером) усугубляет ситуацию, так как дожидается самого медленного запроса. batching|Continuous batching (batching|динамическое пакетирование) повышает утилизацию, но не устраняет фундаментальную проблему — генерация остаётся memory-bound при разумном размере батча. Лишь при очень большом batch (>32 на H100) можно перейти к compute-bound режиму, но это редко достижимо в реальных сценариях.


1. Ключевые термины

  • Decode stage: этап инференса LLM, когда модель генерирует токены один за другим (autoregressive decoding). После того, как prefill stage обработал все входные токены параллельно, decode stage выполняется итеративно: на каждом шаге генерируется один новый токен, и вычисляется скрытое состояние.
  • Batch (батч): группа запросов, обрабатываемых на GPU одновременно. Позволяет увеличить throughput, утилизируя параллелизм.
  • Memory-bound: ситуация, когда производительность ограничена пропускной способностью памяти (bandwidth), а не вычислительной мощностью (FLOPS). Пайплайн простаивает, ожидая загрузки данных из HBM.
  • Compute-bound: производительность ограничена вычислительными возможностями (тензорные ядра, CUDA cores). GPU полностью утилизирована в вычислениях.
  • KV cache: кэш ключей и значений для ранее сгенерированных токенов. Хранится в HBM (High Bandwidth Memory) GPU. Его размер растёт линейно с длиной последовательности и числом запросов.

2. Анатомия autoregressive decoding

Каждый шаг декодирования включает:

  1. Загрузку KV cache для предыдущих токенов (всех! — O(sequence_length * d_model) для каждого слоя).
  2. Матричное умножение нового скрытого состояния с матрицами Q, K, V (небольшая часть FLOPS).
  3. Обновление KV cache: запись нового K, V в HBM.
  4. Softmax и Attention (основные вычисления, но они тоже требуют чтения KV cache).

Соотношение операций: на один выданный токен требуется чтение и запись KV cache в HBM. Например, для модели LLaMA-7B с размерностью 4096, длиной контекста 2048 и batch=1, KV cache занимает около 1 MB на слой (в fp16), а всего ~32 слоёв — 32 MB. При batch=16 уже 512 MB. Но главное — каждый шаг нужно загрузить всё это в SRAM. Пропускная способность памяти H100 (~3.35 TB/s) становится лимитом.

Формула затрат памяти на шаг:

Memory accesses per step ≈ (batch_size × num_layers × 2 × d_model × seq_len) * (read KV) + (batch_size × num_layers × 2 × d_model) * (write new KV)

Упрощённо: для каждого запроса на каждом шаге мы читаем весь его предыдущий KV cache (O(seq_len)) и записываем один новый элемент. Тем самым объём данных ~ O(batch × seq_len), а не O(batch × d_model), как в prefill.


3. Проблема static batching (фиксированного батча)

Static batching собирает несколько запросов в один фиксированный батч и обрабатывает их синхронно. Пока все запросы не завершат генерацию, GPU не освобождается. Недостатки:

  • Разная длина генерации: один запрос может генерировать 50 токенов, другой — 500. Самый короткий ждёт самого длинного. GPU простаивает (дожидается последнего).
  • KV cache «хвостов»: для коротких запросов память выделяется на весь максимальный контекст, хотя реально она не используется.
  • Низкая утилизация: после того как некоторые запросы завершились, ячейки батча остаются пустыми, но не освобождаются до конца батча.
ХарактеристикаStatic batchingContinuous batching
СинхронизацияВсе запросы стартуют и завершают вместеЗапросы могут подключаться/отключаться в любой момент
ОжиданиеЖдёт самого медленногоНе ждёт — завершённые удаляются, новые добавляются
Использование памятиФиксированный аллокаторДинамическое выделение (PagedAttention)
ThroughputНизкий при разной длинеВысокий (почти идеальная утилизация)
LatencyВысокая для коротких запросовМожет снижаться из-за конкуренции за вычисления

4. Continuous batching (динамическое пакетирование)

Техника, реализованная в vLLM, TensorRT-LLM, Hugging Face TGI. Каждый шаг декодирования выбирает из очереди несколько запросов, которые действительно готовы генерировать новый токен. Запросы, завершившие генерацию, сразу удаляются из батча; новые запросы (из prefill) могут добавляться.

  • Преимущество: утилизация GPU растёт (больше активных запросов в среднем), пропускная способность (throughput) увеличивается в 2-4 раза по сравнению со static batching.
  • Ограничение: каждый шаг всё равно memory-bound. Даже при полной загрузке батча (например, 64 запроса) — узким местом остаётся пропускная способность памяти при чтении/записи KV cache.

Техническая реализация — PagedAttention (vLLM): KV cache разбивается на страницы (блоки по 16-32 токена), которые выделяются по мере необходимости. Это уменьшает внутреннюю фрагментацию и позволяет эффективнее использовать HBM.


5. Почему decode stage остаётся memory-bound даже при continuous batching

  1. Арифметическая интенсивность низкая: на каждый байт, загруженный из HBM (KV cache), выполняется всего несколько FLOPS.
    Arithmetic intensity ≈ (операций) / (байты, переданные из HBM).

    Для decode она составляет единицы (например, 1-4 FLOP/байт), тогда как для prefill — сотни (100-1000 FLOP/байт).
    Пороговое значение для H100 (на котором memory-bound переходит в compute-bound) — около 100-150 FLOP/байт. Decode далёк от этого.

  2. KV cache постоянно растёт: при длинном контексте (например, 4K токенов) каждый запрос требует ~ несколько сотен МБ. Даже с PagedAttention приходится читать сотни страниц на шаг.

  3. Современные архитектуры (FlashAttention 2/3) оптимизируют prefill, делая его compute-bound, но для decode основной выигрыш — только за счёт редукции объёма пересылаемых данных (Multi-Query Attention, Grouped Query Attention).


6. Когда decode может стать compute-bound?

Теоретически при очень большом batch (>32 на H100) начинает доминировать вычислительная часть (матричные умножения). Например, если batch_size = 128 и seq_len = 1024, то KV cache на шаг ~ 128 × 1024 × d_model × слои — гигантский, но GPU может загрузить параллельно. При этом количество FLOPS для attention (softmax + weighted sum) может стать сопоставимым с объёмом памяти.

Однако на практике:

  • Размер батча ограничен доступным VRAM (даже H100 80GB не вместит 128 запросов с длинным контекстом).
  • Предел HBM bandwidth (~3.35 TB/s) — при batch 128 и seq_len=4K, каждый шаг «съедает» более 2 ТБ/с только на чтение KV cache. Остаётся запас около 1 ТБ/с для вычислений, что ещё не делает задачу compute-bound.

Вывод: в продакшене (batch до 64, контекст до 8K) decode stage остаётся deeply memory-bound.


7. Сравнение prefill vs decode

ПараметрPrefill stageDecode stage
ДлительностьОдна итерация (parallel)Много шагов (autoregressive)
Основная операцияAttention над всеми токенамиAttention + генерация 1 токена
Арифметическая интенсивностьВысокая (100-1000 FLOP/byte)Низкая (1-10 FLOP/byte)
BottleneckCompute (FLOPs)Memory (HBM bandwidth)
Выигрыш от batchingЛинейный рост throughputСублинейный (быстро упирается в bandwidth)
Влияние continuous batchingНезначительное (prefill обычно вне очереди)Ключевое (утилизация)

8. Влияние на latency и throughput

  • Latency (задержка): для одного запроса decode stage memory-bound → время генерации токена почти не зависит от batch, пока батч помещается в пропускную способность памяти. Добавление ещё одного запроса в батч не увеличит заметно время одного шага (линейный рост памяти, но bandwidth делится). Парадокс: в continuous batching latency может возрасти из-за конкуренции, но по факту при правильной планировке сохраняется стабильной.
  • Throughput (пропускная способность): растёт с размером батча, но с насыщением. График: резкий подъём при batch 1 → 4, потом замедление, после batch 32 почти плато.

9. Техники смягчения проблемы

  1. Multi-Query Attention (MQA) и Grouped Query Attention (GQA): уменьшают размер KV cache (несколько ключей/значений на много запросов), тем самым снижают объём памяти на шаг и переводят decode в чуть более compute-bound.
  2. Speculative Decoding: использует маленький «черновик» (draft model) для генерации нескольких токенов за один шаг, затем big model проверяет. Фактически уменьшает количество decode steps, сокращая общее время ожидания памяти.
  3. PagedAttention (vLLM) и FlashDecoding: оптимизируют чтение/запись KV cache, частично скрывают latency памяти за счёт асинхронности и tiling.
  4. KV cache compression (квантование, обрезание, спорадические pattern): уменьшает объём данных, который нужно передать.

10. Заключение

Decode stage плохо батчится, потому что он memory-bound и даже при continuous batching основным лимитом является пропускная способность HBM. Static batching усугубляет ситуацию за счёт синхронизации по самому длинному запросу. Continuous batching — обязательный стандарт для inference серверов, но он не превращает decode в compute-bound. Чтобы добиться масштабирования, инженеры комбинируют микро-архитектурные трюки (GQA, speculative decoding) и системные оптимизации (PagedAttention, квантование KV cache).


Пет-проект для закрепления

Задача: написать микро-бенчмарк на Python с использованием PyTorch и тритонов/FlashAttention для сравнения static и continuous batching на маленькой модели (типа GPT-2 small).

Инструменты: PyTorch, Hugging Face Transformers, vLLM (можно установить как библиотеку), time, nvidia-smi мониторинг.

Шаги:

  1. Загрузить модель GPT-2 (или TinyLlama 1.1B).
  2. Сгенерировать 100 запросов с разной длиной (от 16 до 512 токенов). Использовать одинаковое количество запросов для static и continuous.
  3. Реализовать static batching: сгруппировать запросы в батчи фиксированного размера, выполнять генерацию, замерить total time и throughput (токенов/сек).
  4. Использовать vLLM (под капотом continuous batching) для тех же запросов, замерить те же метрики.
  5. Построить графики зависимости throughput от batch size (1,2,4,8,16,32,64) и длины контекста.
  6. Ожидаемый результат: vLLM покажет в 2-4 раза higher throughput при batch > 8; static batching будет страдать от «хвостов» — падение throughput при увеличении разброса длины.

Дополнительно: включить профилирование через PyTorch profiler, чтобы увидеть, где тратится время (memory copy vs. compute).


Связь с другими вопросами

ВопросТема
436Что такое continuous batching и как оно реализовано?
438Объясните PagedAttention и его роль в оптимизации KV cache.
401В чём разница между prefill и decode stages?
405Как Speculative Decoding улучшает время генерации?
410Какие метрики latency/throughput критичны для LLM inference?
420Зачем нужна Multi-Query Attention (MQA)?

Навигация