Реализовать diffusion LLM (PLANNER)
ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Реализовать diffusion LLM (PLANNER)
1. Цель задачи
Разработать и обучить упрощённую версию diffusion language model с планировщиком (PLANNER), который позволяет генерировать несколько токенов за один шаг денойзинга. В отличие от классических авторегрессивных LLM, diffusion-подход итеративно превращает случайный шум в текст, а PLANNER предсказывает сразу несколько позиций (слов) в предложении, ускоряя генерацию коротких ответов.
Ключевой результат Рабочая модель, которая на коротких ответах (длиной до 128 токенов) генерирует текст в 2 раза быстрее (по latency) по сравнению с baseline-авторегрессивной моделью той же архитектуры, с сопоставимым качеством (perplexity не выше baseline на 15%).
2. Исходные данные
| Что нужно | Откуда взять |
|---|---|
| Предобученная BERT-base (или аналогичный encoder) | HuggingFace google-bert/bert-base-uncased |
| Датасет для обучения / валидации | OpenWebText, wiki-text-2 или TinyStories (рекомендуется для быстрого обучения) |
| Скрипты для измерения latency | Написать самостоятельно (time.perf_counter) |
| Baseline-модель для сравнения | Та же BERT-base, но с авторегрессивной генерацией (например, через causal LM head) |
Если нет возможности обучить на большом датасете — симулируем:
- Используем TinyStories (HuggingFace: roneneldan/TinyStories) — 2–3 эпохи обучения на подвыборке 50K примеров.
- Baseline — это та же модель с заменой diffusion head на causal LM head, обученная на том же датасете (можно с тем же числом шагов).
- Для замера ускорения используем один GPU (T4 или аналоги) и фиксированный batch_size=1, измеряем среднее время генерации 10 коротких запросов (длина ответа 32, 64, 128 токенов).
3. Технологический стек
| Компонент | Инструменты | Назначение |
|---|---|---|
| Фреймворк | PyTorch 2.x + HuggingFace Transformers | Загрузка модели, обучение, инференс |
| Диффузионный модуль | Собственная реализация на PyTorch | Noising/sampling (дискретный diffusion) |
| Планировщик (PLANNER) | PyTorch nn.Module | Предсказание нескольких позиций за шаг |
| Мониторинг обучения | Weights & Biases или TensorBoard | Отслеживание loss, perplexity, ускорения |
| Генерация текста | Python + собственный инференс-скрипт | Замер latency и качества |
| Визуализация | matplotlib / seaborn | Графики "ускорение vs длина ответа" |
4. Этапы выполнения
Этап 1: Исследование и архитектура (2 часа)
Действия
-
Изучить основные работы по дискретному diffusion для текста (D3PM, MDLM, DiffusionBERT). Выбрать простой подход: multinomial diffusion (каждый токен — one-hot вектор, шум — uniform по словарю).
-
Определить формат PLANNER: вместо предсказания всего предложения за один шаг — модель предсказывает маску для нескольких позиций (например, 25% токенов за шаг). На каждом шаге денойзинга планировщик выбирает случайное подмножество позиций и обновляет их, остальные остаются зашумлёнными.
-
Зафиксировать гиперпараметры:
Ожидаемый результат этапа Документ с обоснованием выбора архитектуры, диаграмма процесса генерации (шум → итеративное денойзинг с PLANNER). Код пока не пишется.
Этап 2: Реализация процесса diffusion (3 часа)
Действия
-
Реализовать
forward_noising— превращение чистого текста в последовательность шумовых векторов за T шагов (диффузионный процесс):- На входе: one-hot токены (batch, seq_len, vocab_size)
- На шаге t: смешиваем исходный вектор с uniform-шумом с интенсивностью β(t) (cosine schedule)
- Выход: зашумлённое распределение на каждом шаге t.
-
Реализовать
reverse_sampling— обратный процесс:- Начинаем с чистого шума (uniform) на шаге T.
- Для t от T-1 до 0:
- Модель получает текущее зашумлённое состояние.
- PLANNER выбирает маску (какие позиции обновлять).
- Обновляем только выбранные позиции согласно предсказанию модели.
- Возвращаем финальный текст (argmax по последнему шагу).
-
Написать тестовая проверка на одном предложении: зашумление → денойзинг → проверка, что восстановленный текст близок к исходному.
Ожидаемый результат этапа Модули diffusion_utils.py с функциями add_noise и sample_step. Проверочный скрипт, который показывает, что при T=32 и правильных предсказаниях текст восстанавливается.
Этап 3: Реализация PLANNER (4 часа)
Действия
-
Добавить к BERT дополнительную голову — планировщик:
- Это MLP + линейный слой, который из скрытых состояний BERT предсказывает маску выбора (binary mask, sigmoid) + новые токены (softmax по словарю).
- На вход: скрытые состояния после последнего слоя BERT (batch, seq_len, hidden_size).
- На выход:
logits_mask(batch, seq_len, 1) иlogits_token(batch, seq_len, vocab_size).
-
Реализовать loss-функцию:
- Для каждой позиции, если она выбрана маской (ground truth маска — те позиции, которые мы знаем, что нужно обновить), применяется cross-entropy loss к предсказанному токену.
- Если не выбрана — loss только за то, что маска предсказала не обновлять (BCE).
- Суммарный loss:
L_total = L_ce + λ * L_mask, где λ=0.1.
-
Имплементировать предиктор маски во время обучения: сначала модель предсказывает маску, затем мы применяем её к ground truth токенам (через teacher forcing) — на каждом шаге мы знаем, какие позиции должны быть обновлены (например, самые зашумлённые).
Ожидаемый результат этапа Класс DiffusionBERT с forward методом, который принимает (input_ids, noise_level_t) и возвращает (logits_mask, logits_token). Проход обучения для одного батча работает без ошибок.
Этап 4: Обучение и оценка (5 часов)
Действия
-
Подготовить датасет:
- Загрузить TinyStories, отфильтровать примеры длиной ≤128 токенов.
- Токенизировать BERT-токенизатором (add_special_tokens=True).
- Создать DataLoader с padding до max_length и attention mask.
-
Реализовать цикл обучения (1000 шагов, валидация каждые 100 шагов):
- Для каждого батча:
- Выбрать случайный шаг t (0..T-1).
- Применить шум к batch (get noisy_embedding, mask_gt — какие токены исходные, какие шумовые).
- Прогнать модель, получить потери.
- Backprop.
- На валидации: измерять perplexity на зашумлённых данных (по предсказанию токенов).
- Для каждого батча:
-
После обучения зафиксировать модель и сохранить чекпоинт.
Ожидаемый результат этапа Обученная модель diffusion_bert_checkpoint.pt. График loss и perplexity на валидации.
Этап 5: Инференс и замер ускорения (3 часа)
Действия
-
Реализовать функцию
generate_diffusion(prompt, max_length=128):- prompt токенизируется и паддится до max_length.
- Инициализируем шум (uniform) по всем позициям.
- Запускаем обратный процесс с PLANNER (T=32 шага).
- Замеряем время генерации.
-
Реализовать baseline — авторегрессивная генерация с той же BERT-base (causal LM head):
- Использовать
transformers.AutoModelForCausalLM(можно дообучить на том же датасете). - Замеряем время для такой же длины ответа.
- Использовать
-
Сравнить latency для трёх длин: 32, 64, 128 токенов (по 20 повторений на каждую).
- Вычислить среднее время и ускорение (time_baseline / time_diffusion).
-
Оценить качество сгенерированных ответов:
- Perplexity по модели (встроенный loss).
- BLEU-1, BLEU-2 на 100 примерах (если есть референсные ответы).
- Ручная оценка осмысленности 10 ответов.
Ожидаемый результат этапа Таблица ускорения, графики, сгенерированные примеры. Ускорение 2x на short ответах — подтверждено.
5. Критерии приемки (Definition of Done)
- Реализован полный пайплайн: noising, sampling, обучение, инференс.
- Обученная diffusion-модель способна генерировать текст длиной до 128 токенов.
- Замеры latency показывают ускорение ≥2× для ответов 32-64 токена (сравнение с baseline).
- Perplexity на валидации не превышает baseline более чем на 15%.
- Код воспроизводим: README с инструкциями, requirements.txt, все скрипты в папке.
- Сгенерированные тексты визуально осмысленны (не полностью случайны).
- Есть отчёт в виде Jupyter Notebook с графиками и выводами.
6. Ожидаемый результат
Файлы и артефакты
| Артефакт | Описание |
|---|---|
src/ | Папка с кодом (diffusion_utils.py, model.py, train.py, generate.py) |
checkpoints/diffusion_bert_checkpoint.pt | Обученная модель |
results/latency_comparison.csv | Таблица latency baseline vs diffusion |
results/generated_samples.txt | Примеры сгенерированных текстов |
report.ipynb | Jupyter Notebook с анализом, графиками, выводами |
requirements.txt | Зависимости |
Дополнительные результаты (опционально):
- Демо в Streamlit для интерактивной генерации.
- Сравнение с другими diffusion-подходами (MDLM, D3PM).
- Анализ влияния числа шагов T на скорость/качество.
7. Возможные сложности и их решение
| Сложность | Решение |
|---|---|
| Долгое обучение (на GPU без поддержки) | Использовать TinyStories (маленький датасет), уменьшить T до 16, batch size 8 |
| Нестабильный loss (особенно маска) | Увеличить λ до 0.5, добавить gradient clipping (max_norm=1.0) |
| Низкое качество генерации | Включить classifier-free guidance (CFG) со scale=1.5 |
| Ускорение меньше 2x | Увеличить долю обновляемых позиций до 0.5, уменьшить T до 8, использовать FP16 |
| Проблема с attention mask | BERT требует правильной маски при паддинге — передавать attention_mask в модель |
8. Бюджет времени (оценка)
| Этап | Время (часы) |
|---|---|
| Этап 1: Исследование и архитектура | 2 |
| Этап 2: Реализация процесса diffusion | 3 |
| Этап 3: Реализация PLANNER | 4 |
| Этап 4: Обучение и оценка | 5 |
| Этап 5: Инференс и замер ускорения | 3 |
| Итого | 17 часов |
Примечание: Время указано для опытного ML-инженера. Первый раз может потребоваться +50% времени на отладку.
9. Связанные вопросы из базы знаний
| Вопрос | Тема |
|---|---|
| 101 | Дискретный diffusion для текста |
| 105 | Multinomial diffusion D3PM |
| 112 | Архитектура BERT и её модификации |
| 207 | Attention masking в трансформерах |
| 304 | Cosine noise schedule |
| 408 | Ускорение инференса LLM (специализированные техники) |
| 512 | Модульное тестирование пайплайнов генерации |
| 623 | Classifier-free guidance (CFG) |
| 734 | Сравнение авторегрессивных и diffusion-моделей |
| 851 | Практики замеров latency в NLP |
10. Чек-лист самопроверки
- Я реализовал оба процесса (noising и sampling) и проверил на одном примере, что текст восстанавливается (хотя бы частично).
- Я обучил модель хотя бы 500 шагов и убедился, что loss уменьшается.
- Я измерил latency baseline и diffusion на одинаковых аппаратных условиях (GPU, batch_size=1).
- Я сравнил качество через perplexity на валидации — разница ≤15%.
- Я написал README с инструкцией по запуску и требованием к среде.
- Я зафиксировал seed (например, 42) для воспроизводимости.