Как работает packing для variable-length sequences в FSDP?

Краткий тезис

Packing — это техника объединения нескольких коротких последовательностей разной длины в одну «упакованную» последовательность с помощью attention mask, чтобы избежать неэффективного паддинга. В контексте FSDP (Data Parallelism) packing позволяет значительно увеличить throughput (на 30–50%) за счёт более полного использования GPU-памяти и вычислительных ресурсов, особенно при обучении на данных с переменной длиной (диалоги, документы). FSDP поддерживает packing через внутренние механизмы _set_padded_sequence и _set_sequence_lengths, которые корректно обрабатывают градиенты и коммуникацию между шардами.


1. Проблема variable-length sequences в обучении

При обучении моделей на естественном языке последовательности (предложения, диалоги, документы) имеют разную длину. Стандартный подход — padding (дополнение до максимальной длины в батче токенами <pad>). Это приводит к:

  • Wasted computation — модель обрабатывает пустые токены, тратя FLOPs впустую.
  • Неэффективному использованию памяти — паддинги занимают место в тензорах, увеличивая размер батча искусственно.
  • Проблемам с attention — паддинги должны быть замаскированы (attention mask), что добавляет overhead.

Особенно остро это проявляется в распределённом обучении с FSDP, где каждый GPU хранит только часть параметров, но паддинги всё равно передаются и обрабатываются.


2. Что такое packing

Packing (упаковка) — техника, при которой несколько коротких последовательностей «склеиваются» в одну длинную последовательность, разделённую специальными токенами-разделителями (например, <sep> или <eos>). Для каждой исходной последовательности в упаковке создаётся attention mask, который запрещает токенам из разных исходных последовательностей «видеть» друг друга.

Пример:

  • Исходные последовательности: [A, B, C] (длина 3), [D, E] (длина 2).
  • После паддинга до длины 3: [A, B, C], [D, E, <pad>] — 6 токенов, 1 паддинг.
  • После packing: [A, B, C, <sep>, D, E] (длина 6) — все токены полезные, паддингов нет.

Attention mask для упакованной последовательности будет иметь блочно-диагональную структуру: токены из первой последовательности видят только друг друга, из второй — только друг друга.


3. Как packing работает в контексте FSDP

FSDP (Fully Sharded Data Parallelism) — техника распределённого обучения, при которой параметры, градиенты и оптимизаторные состояния модели шардируются (разбиваются) между GPU. Каждый GPU хранит только часть полной модели, но обрабатывает полный батч данных (data parallelism).

Packing в FSDP работает на уровне DataLoader и collate function:

  1. Сортировка и бакетизация — последовательности сортируются по длине и группируются в бакеты примерно одинаковой суммарной длины.
  2. Упаковка — внутри каждого бакета последовательности пакуются в одну или несколько упакованных последовательностей, стараясь минимизировать паддинг.
  3. Создание attention mask — для каждой упакованной последовательности генерируется маска, которая блокирует кросс-внимание между разными исходными последовательностями.
  4. Передача в модель — упакованные последовательности и маски передаются в модель. FSDP обрабатывает их как обычные тензоры, но благодаря маске градиенты вычисляются корректно для каждой исходной последовательности независимо.

FSDP поддерживает packing через внутренние методы:

  • _set_padded_sequence — помечает тензор как упакованный (содержит padding внутри упаковки).
  • _set_sequence_lengths — хранит длины исходных последовательностей для корректного обратного прохода.

Эти методы гарантируют, что при backward pass градиенты будут агрегированы только внутри каждой исходной последовательности, а не по всей упаковке.


4. Детали реализации packing

4.1 Сортировка и бакетизация

def sort_and_bucket(sequences, max_bucket_tokens=2048):
    # Сортируем по длине (убывание)
    sequences = sorted(sequences, key=len, reverse=True)
    buckets = []
    current_bucket = []
    current_len = 0
    for seq in sequences:
        if current_len + len(seq) <= max_bucket_tokens:
            current_bucket.append(seq)
            current_len += len(seq)
        else:
            buckets.append(current_bucket)
            current_bucket = [seq]
            current_len = len(seq)
    if current_bucket:
        buckets.append(current_bucket)
    return buckets

4.2 Упаковка последовательностей

def pack_sequences(bucket, sep_token_id=1, pad_token_id=0):
    packed = []
    attention_mask = []
    lengths = []
    for seq in bucket:
        packed.extend(seq)
        packed.append(sep_token_id)  # разделитель
        lengths.append(len(seq))
    # Удаляем последний разделитель (опционально)
    packed = packed[:-1]
    # Создаём attention mask (блочно-диагональная)
    total_len = len(packed)
    mask = torch.zeros(total_len, total_len)
    start = 0
    for l in lengths:
        end = start + l
        mask[start:end, start:end] = 1
        start = end + 1  # +1 для разделителя (если он есть)
    # Если разделители не используются, attention mask строится без них
    return packed, mask, lengths

4.3 Использование в DataLoader с FSDP

from torch.utils.data import DataLoader
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def collate_fn(batch):
    sequences = [item['input_ids'] for item in batch]
    buckets = sort_and_bucket(sequences)
    packed_batch = []
    masks = []
    for bucket in buckets:
        packed, mask, lengths = pack_sequences(bucket)
        packed_batch.append(packed)
        masks.append(mask)
    # Pad до одинаковой длины внутри батча (если бакетов несколько)
    # или используем packing с одной упаковкой на батч
    return {'input_ids': torch.stack(packed_batch), 'attention_mask': torch.stack(masks)}

model = FSDP(model)
dataloader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)

5. Преимущества packing в FSDP

АспектБез packing (padding)С packing
Использование GPUДо 40% токенов — паддингиПочти 100% полезные токены
ThroughputБазовый+30–50%
ПамятьБольше из-за паддинговМеньше, можно увеличить batch size
Время обученияДольшеБыстрее за счёт меньшего числа шагов
Attention maskПростая (0/1)Блочно-диагональная, сложнее

6. Trade-offs и подводные камни

  • Сложность реализации — требуется custom collate_fn и корректная обработка attention mask.
  • Overhead на сортировку — сортировка последовательностей перед каждым батчем может замедлить DataLoader, но это компенсируется ускорением обучения.
  • Проблемы с convergence — если в одну упаковку попадают последовательности из разных доменов, модель может путать контексты. Рекомендуется группировать семантически близкие последовательности.
  • Совместимость с FSDP — не все реализации FSDP корректно обрабатывают упакованные тензоры. Необходимо использовать версию PyTorch >= 2.0 и проверять поддержку _set_padded_sequence.
  • Ограничение на максимальную длину — упакованная последовательность не должна превышать max_position_embeddings модели.

7. Сравнение с альтернативами

МетодОписаниеПлюсыМинусы
PaddingДополнение до макс. длины в батчеПростотаНизкая эффективность
Dynamic batchingГруппировка последовательностей одинаковой длиныУмеренная эффективностьТребует сортировки, не идеально
PackingОбъединение нескольких последовательностейМаксимальная эффективностьСложность, overhead на маску
Gradient accumulationНесколько маленьких батчейПростотаУвеличивает время обучения

8. Когда использовать packing

Packing особенно эффективен в сценариях с высокой вариативностью длины последовательностей:

  • Диалоговые данные — сообщения разной длины.
  • Документы — от коротких заметок до длинных статей.
  • Мультимодальные данные — текст + изображения разного размера.
  • Fine-tuning на смешанных датасетах — где длина сильно варьируется.

Если данные уже имеют примерно одинаковую длину (например, стандартизированные логи), выигрыш от packing будет минимальным.


Пет-проект для закрепления

Задача Реализовать packing для fine-tuning небольшой LLM (например, GPT-2) на датасете диалогов (DailyDialog) с использованием FSDP на 2 GPU.

Инструменты PyTorch, Hugging Face Transformers, FSDP (PyTorch), Datasets.

Шаги:

  1. Загрузить датасет DailyDialog и токенизировать диалоги.
  2. Написать collate_fn с сортировкой, бакетизацией и упаковкой (как в разделе 4).
  3. Обучить модель GPT-2 с FSDP на упакованных данных.
  4. Сравнить throughput (токенов/сек) и loss с baseline (padding) при одинаковом batch size.
  5. Построить график зависимости throughput от степени вариативности длины (симулировать разные распределения).

Ожидаемый результат Ускорение обучения на 30–50% при сохранении качества (loss не должен отличаться более чем на 0.05). Код с комментариями и визуализацией.


Связь с другими вопросами

ВопросТема
475Как работает FSDP и чем отличается от DDP?
477Какие стратегии шардирования параметров существуют в FSDP?
480Как оптимизировать throughput при distributed training?
482Что такое gradient checkpointing и как его использовать с FSDP?
485Как работает attention mask в transformer-моделях?
490Какие методы борьбы с паддингом вы знаете?

Навигация