Как вы делаете fine-tuning на последовательностях разной длины (packing, dynamic batching)?
Краткий тезис
При fine-tuning LLM на естественных данных длины последовательностей почти всегда различаются. Прямое использование padding до максимальной длины в батче ведёт к потере памяти и времени на обработку пустых токенов. Два основных подхода — packing (объединение нескольких коротких примеров в одну последовательность) и dynamic batching (группировка батчей по схожей длине) — существенно повышают эффективность обучения. Современные реализации (например, FlashAttention) поддерживают переменные длины без потери производительности, что делает packing особенно удобным.
-------|----------| | Разбазаривание памяти | В батче из коротких и длинных примеров большая часть токенов — padding, особенно при высоком разбросе длин. | | Лишние вычисления | Модель (особенно Transformer) выполняет операции со всеми токенами, включая паддинг; softmax в attention тратит ресурсы на маскированные позиции. | | Ограничение batch size | Из-за padding батч занимает больше видеопамяти, что вынуждает уменьшать batch size и замедляет сходимость. |
1.3. Когда оправдано
- Очень короткие фиксированные входы (например, классификация коротких текстов).
- Простота реализации в фреймворках (PyTorch
DataLoaderсcollate_fnпо умолчанию).
2. Packing: объединение коротких последовательностей
2.1. Идея
Несколько исходных примеров конкатенируются в одну последовательность (до максимальной длины контекста модели). Каждый пример отделяется специальным токеном конца последовательности (e.g. <|endoftext|> или [SEP]) или используется известный формат (например, при обучении GPT-подобных моделей). Модель обучается на всей последовательности, при этом для каждого подсегмента внутри последовательности используется attention mask, запрещающий взаимодействие между примерами.
2.2. Реализация
# Псевдокод packing для GPT-2 (с Hugging Face)
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
texts = ["Hello world", "How are you?", "Fine-tuning LLMs"]
# Токенизируем каждый пример, добавляем eos_token между ними
tokens = [tokenizer.encode(t) + [tokenizer.eos_token_id] for t in texts]
flat_tokens = sum(tokens, []) # просто конкатенация
# Создаём attention mask: блокируем взаимодействие между примерами
# (здесь упрощённо — для каждого подсегмента свой 1, а между подсегментами 0)
# На практике можно использовать пакетные версии или кастомный коллейтер.
Полноценный packing требует:
- Position IDs — для каждого подсегмента свои позиции (начиная с 0) или глобальные, но с маскировкой.
- Loss calculation — обычно loss считается только на токенах входа (при causal LM) или на всех позициях, кроме padding/packing-разделителей.
- Внимание: если между подсегментами разрешить взаимодействие, модель научится «подглядывать» в соседние примеры — это нежелательно. Поэтому нужна block diagonal attention mask.
2.3. Преимущества
- Экономия GPU памяти: нет пустых padding-токенов, все вычисления идут только над реальными данными.
- Увеличивается batch size в смысле количества обрабатываемых примеров (за счёт объединения).
- Пропускная способность: больше токенов за шаг обучения при той же вычислительной мощности.
2.4. Недостатки
- Сложность реализации корректной маски внимания и position ids.
- При коротких примерах и большом их количестве packing может генерировать очень длинные последовательности, требующие больше памяти на один элемент батча (но суммарная эффективность обычно выше).
- Не подходит для задач, где каждый пример должен предсказывать свой собственный ответ независимо (например, QA с разными контекстами).
3. Dynamic batching: группировка по длине
3.1. Концепция
Батчи формируются не случайно, а так, чтобы внутри одного батча длины последовательностей были максимально близки. Это позволяет использовать padding с минимальным объёмом лишних токенов. Длины можно распределить по «корзинам» (buckets), например: [0-50], [51-100], [101-200] и т.д.
3.2. Реализация
# Hugging Face Data Collator для динамического батчинга
from transformers import DataCollatorWithPadding
import numpy as np
# Сначала отсортировать данные по длине (или использовать bucket)
# Затем при колляции:
collator = DataCollatorWithPadding(tokenizer, padding="longest")
# Внутри он берёт самый длинный пример в текущем батче и паддит все остальные до его длины.
# Если батчи заранее сгруппированы по длине, "longest" будет близок к средней длине батча.
Более продвинутый вариант — bucket‑based batching: алгоритм группирует примеры по длине в N корзин, затем внутри каждой корзины перемешивает и формирует батчи. Пример с torchdata или datasets:
from datasets import Dataset, load_dataset
from tokenizers import Tokenizer
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
def tokenize_function(examples):
return tokenizer(examples["text"], truncation=False)
tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
# Сортируем по длине
sorted_dataset = tokenized_dataset.sort("length")
# Формируем корзины по 1000 примеров (bucket size)
bucket_size = 1000
buckets = np.array_split(sorted_dataset, len(sorted_dataset) // bucket_size)
# Внутри каждой корзины создаём DataLoader с DataCollatorWithPadding
# Это даёт низкий разброс длин внутри батча.
3.3. Плюсы и минусы
| Характеристика | Плюсы | Минусы |
|---|---|---|
| Простота | Реализуется через стандартный collator + сортировка | Требует предварительной сортировки (может нарушить случайность) |
| Эффективность памяти | Padding почти равен средней длине батча | Не устраняет padding полностью |
| Дополнительно | Легко комбинируется с другими оптимизациями (Gradient Accumulation) | При очень неравномерном распределении длин всё равно могут быть «хвосты» |
3.4. Вариация: Dynamic Padding в PyTorch
Можно написать собственный collate_fn, который динамически определяет максимальную длину в батче и паддит до неё — это минимальный уровень динамического батчинга. Однако без группировки по длине эффект будет слабый.
4. FlashAttention и поддержка packing
4.1. Что такое FlashAttention
FlashAttention — алгоритм, который вычисляет точное attention без материализации полной матрицы S = Q * K^T, используя кумулятивную обработку блоков и тилинг для эффективного использования SRAM на GPU. Это позволяет:
- Значительно ускорить forward/backward (до 2-4x).
- Снизить потребление памяти (память $O(n)$ вместо $O(n^2)$ для одного слоя).
4.2. Как FlashAttention помогает packing
При packing часто требуется block diagonal mask, чтобы внутри одного объединённого примера подсегменты не видели друг друга. Обычная реализация attention mask для такой маски — большая разреженная матрица, что неэффективно.
FlashAttention (начиная с версии 2) поддерживает custom attention mask через указание блоков: можно передать mask в виде тензора, который определяет, какие пары токенов должны быть замаскированы. Для packing естественно использовать маску, где каждый подсегмент внутри последовательности имеет свою диагональ, а между подсегментами — -inf.
# Пример конфигурации в Hugging Face с FlashAttention-2:
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b",
attn_implementation="flash_attention_2"
)
# При forward передаём attention_mask с block-diagonal структурой
Таким образом, packing с FlashAttention:
- Не требует дополнительного копирования данных для создания отдельных батчей.
- Обрабатывает блоки внимания последовательно, что снижает пиковое использование памяти.
- Позволяет использовать единое пространство позиций (sequential position ids) без сброса, так как маска заботится о разделении.
4.3. Практические соображения
- Поддержка моделей: FlashAttention доступен для LLaMA, Mistral, Falcon, GPT-NeoX и многих других (через Hugging Face или ядра xformers).
- Ограничения: FlashAttention 2 требует GPU с архитектурой Ampere (A100, A6000, RTX 3090/4090) или новее. Для старых карт может использоваться FlashAttention 1 или обычные реализации.
- Packing + FlashAttention — стандарт для fine-tuning больших моделей (например, в библиотеках MegaBlocks, RecurrentGemma, и в собственных решениях крупных компаний).
5. Пет-проект для закрепления
Задача: Написать скрипт, который fine-tunes модель distilgpt2 на датасете wikitext-2 с использованием packing и dynamic batching по отдельности, и сравнить скорость обучения (tokens/sec) и использование GPU памяти.
Инструменты
- Python 3.10+
- PyTorch 2.0+ с CUDA (или
mpsдля Apple Silicon) - Transformers, Datasets, Accelerate
- GPU с хотя бы 6GB VRAM (подойдёт RTX 3060)
- Библиотека
nvidia-smiилиtorch.cuda.memory_summary()для мониторинга памяти
Шаги
- Загрузите датасет и токенизируйте (добавьте
eos_token). - Реализуйте DataCollator с padding (базовый baseline).
- Реализуйте DataCollator с dynamic batching: отсортируйте данные по длине, разбейте на батчи одинаковой длины (например, батч из 16 примеров, но все с длиной ±10 токенов).
- Реализуйте DataCollator с packing: напишите функцию, которая конкатенирует несколько примеров до максимальной длины контекста (пусть 512), создаёт block-diagonal mask.
- Для каждого варианта:
- Запустите один шаг обучения (unfreeze все параметры).
- Замерьте время шага, количество токенов в батче и пиковое потребление памяти.
- Повторите 10 раз и усредните.
- Постройте таблицу сравнения (tokens/sec, память, loss за шаг).
Ожидаемый результат
- Dynamic batching ускорит обучение в ~1.5-2 раза по сравнению с обычным padding (переменная длина без сортировки).
- Packing даст дополнительный прирост ~20-30% за счёт полного отсутствия padding-токенов.
- FlashAttention (если поддерживается) добавит ещё 2-3x по скорости, но может не пригодиться на distilgpt2 (реализуйте для любой каузальной LM с поддержкой FlashAttention 2).
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 467 | Как выбрать batch size при fine-tuning LLM? |
Вопрос 467 обсуждает влияние batch size на сходимость и память, а packing/dynamic batching напрямую влияют на эффективный batch size (число обрабатываемых примеров).
Навигация
- Предыдущий: 975
- Следующий: 977
- Индекс: 00. Индекс разборов