English translation is not available yet. Showing Russian content.
Настроить гибридную архитектуру Mamba + Attention для улучшения качества языковой модели
ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Настроить гибридную архитектуру Mamba + Attention для улучшения качества языковой модели
1. Цель задачи
Создать и обучить гибридную архитектуру, которая использует механизм самовнимания (Attention) для моделирования глобальных зависимостей (дальние токены) и Mamba (State Space Model) для эффективного локального контекста. Требуется сравнить качество гибридной модели с чистым трансформером одинаковой вычислительной ёмкости на задаче языкового моделирования (perplexity) и, опционально, на downstream задачах (например, логическое рассуждение на CLUTRR или QA).
Ключевой результат Гибридная модель (Mamba + Attention) демонстрирует лучший perplexity на тестовом наборе (WikiText-103) по сравнению с чистым трансформером при равном или меньшем числе параметров и / или времени инференса; метрики качества (exact match / F1) на задаче синтеза длинных текстов не ниже baseline.
2. Исходные данные
| Что нужно | Откуда взять |
|---|---|
| Фреймворк глубокого обучения | PyTorch 2.x (установить через pip) |
| Реализация слоя Mamba | Репозиторий mamba-ssm (pip install mamba-ssm) или самописная реализация (см. раздел симуляции) |
| Базовая модель Transformer | HuggingFace transformers – GPT-2 small (124M) как baseline |
| Датасет для языкового моделирования | WikiText-103 (HuggingFace datasets) – train/validation/test |
| Датасет для downstream оценки (опционально) | CLUTRR (логические цепочки) или LongBench (обобщение длинных текстов) |
| Средство логирования | Weights & Biases (wandb) или MLflow – завести проект mamba_attention_hybrid |
| Хранилище результатов | Локальная папка ./experiments/ с чекпоинтами и метриками |
Если нет реального инструмента — симулируем:
-
Если mamba-ssm не устанавливается (например, из-за CUDA): реализуем упрощённый Mamba-блок на чистом PyTorch.
- Использовать готовый код из репозитория mamba (файл
mamba_ssm/modules/mamba_simple.py) – скопировать какmamba_block.py. - Заменить cuda-зависимые операции (например, selective scan) на torch.einsum + сканирование через torch.cumsum (потеря скорости, но сохранение функциональности).
- Проверить корректность на случайном тензоре длины 512.
- Использовать готовый код из репозитория mamba (файл
-
Если нет GPU: все запуски проводить на CPU с reduced batch size (1-2) и ограниченной длиной (128 токенов), обучение 1-2 эпохи только для проверки сходимости.
3. Технологический стек
| Компонент | Инструменты | Назначение |
|---|---|---|
| Фреймворк | PyTorch 2.2 | Обучение и инференс |
| Mamba-реализация | mamba-ssm (v1.2+) или самописная (sim_mamba) | SSP-слой для локального контекста |
| Трансформер-блоки | HuggingFace Transformers (GPT-2Block) | Attention-слои для глобального контекста |
| Датасеты | datasets (WikiText-103) | Тренировка и валидация |
| Токенизатор | GPT-2 tokenizer (HuggingFace) | Токенизация текста |
| Логирование | Weights & Biases (wandb) | Отслеживание метрик, лосса, perplexity |
| Гиперпараметры | Hydra / argparse | Управление конфигурацией |
| Среда | Linux (Ubuntu 22.04), Python 3.10 | Выполнение экспериментов |
| Метрики | Perplexity (скользящая), loss, speed (tokens/sec) | Качество и эффективность |
4. Этапы выполнения
Этап 1: Подготовка окружения и данных (2-3 часа)
Действия
-
Создать виртуальное окружение и установить зависимости:
python -m venv venv source venv/bin/activate pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 pip install transformers datasets wandb mamba-ssm hydra-coreЕсли mamba-ssm не ставится – установить самописную реализацию из
mamba_block.py(без cuda-оптимизаций). -
Загрузить датасет и токенизатор
from datasets import load_dataset from transformers import GPT2Tokenizer dataset = load_dataset("wikitext", "wikitext-103-raw-v1") tokenizer = GPT2Tokenizer.from_pretrained("gpt2") tokenizer.pad_token = tokenizer.eos_token -
Подготовить DataLoader
- Объединить все тексты в один длинный поток (concatenate с
\n). - Разделить на чанки по block_size=1024 токенов.
- Создать Dataset с перекрытием (stride=512) для обучения.
- Базовая конфигурация: batch_size=4 (GPU) / 1 (CPU), num_workers=2.
- Объединить все тексты в один длинный поток (concatenate с
-
Записать baseline (чистый трансформер):
- Взять GPT-2 small из HuggingFace: model = GPT2LMHeadModel.from_pretrained("gpt2").
- Обучить 1 эпоху на WikiText-103 (или зафиксировать pretrained perplexity на валидации).
Ожидаемый результат этапа Готовый train_loader, val_loader, baseline модель (GPT-2) с измеренным perplexity на валидации (~30-35 для pretrained). Результат записан в baseline_perplexity.json.
Этап 2: Проектирование и реализация гибридной архитектуры (4-6 часов)
Действия
-
Спроектировать конфигурацию (Hydra YAML):
# config.yaml model: vocab_size: 50257 hidden_size: 768 num_layers: 12 hybrid_every: 3 # каждый 3-й слой – Attention, остальные Mamba attn_layers: [3, 6, 9, 12] # список позиций (1-indexed) training: lr: 3e-4 epochs: 5 block_size: 1024 -
Написать класс HybridModel
- Стек из
num_layersблоков: для каждого слоя i, если i вattn_layers– GPT2Block (Attention), иначе – MambaBlock (из mamba-ssm или самописный). - Head на выходе: nn.Linear(hidden_size, vocab_size).
- Внимание: GPT2Block использует causal masking.
- MambaBlock: слой Mamba (размерность
d_model=hidden_size,expand=2).
- Стек из
-
Реализовать самописный Mamba (если нет mamba-ssm):
class SimpleMamba(nn.Module): def __init__(self, d_model, d_state=16, expand=2): super().__init__() # Параметры: A_log, D, dt_proj, in_proj, x_proj, out_proj self.in_proj = nn.Linear(d_model, expand * d_model) self.ssm = SSM(d_model * expand, d_state) # самописный selective scan self.out_proj = nn.Linear(expand * d_model, d_model) def forward(self, x): # x: (B, L, D) residual = x x = self.in_proj(x) x = self.ssm(x) x = self.out_proj(x) return x + residual(Подробную реализацию selective scan см. в attached
ssm_scan.py– предоставить в репозитории). -
Проверить суммарное число параметров
- Гибридная модель должна иметь не больше параметров, чем GPT-2 small (124M). Если больше – уменьшить
hidden_sizeилиexpand.
- Гибридная модель должна иметь не больше параметров, чем GPT-2 small (124M). Если больше – уменьшить
Ожидаемый результат этапа Класс HybridModel (файл model.py), проверка на синтетическом батче – loss не NaN, forward проходит.
Этап 3: Обучение и логирование (6-8 часов на GPU / 20+ часов на CPU)
Действия
-
Интегрировать с
wandb- Инициализация:
wandb.init(project="mamba_attention_hybrid", name="hybrid_v1"). - Логировать: loss (train/val), perplexity, learning rate, tokens_per_sec.
- Инициализация:
-
Написать training loop
- Использовать AdamW с
lr=3e-4,weight_decay=0.1. - Планировщик: cosine decay, warmup steps 500.
- Сохранять чекпоинты каждые 2500 шагов.
- Валидация каждые 500 шагов на 200 батчах.
- Использовать AdamW с
-
Параллельно обучить baseline (если не был предобучен):
- Та же процедура, те же гиперпараметры, модель –
GPT2LMHeadModel(не предобученная, случайная инициализация).
- Та же процедура, те же гиперпараметры, модель –
-
Сравнивать динамику
- Оба эксперимента фиксировать в wandb с тегом
baseline/hybrid. - Замерить perplexity на валидации после 1, 3, 5 эпох.
- Оба эксперимента фиксировать в wandb с тегом
Ожидаемый результат этапа Два набора чекпоинтов (baseline, hybrid) и wandb дашборд с графиками loss и perplexity. Гибридная модель должна показывать меньший perplexity после 3-5 эпох (разница >1 пункт).
Этап 4: Оценка на downstream задаче (2-3 часа)
Действия
-
Выбрать задачу
-
Сгенерировать ответы для обеих моделей (baseline и hybrid) с
temperature=0.7,max_new_tokens=50. -
Оценить exact match (EM) и F1 на тестовой части CLUTRR (200 примеров).
-
Записать метрики и сравнить
- Ожидание: hybrid модель имеет EM выше на 2-5% из-за лучшего моделирования локальных связей (Mamba помогает с цепочками шагов).
Ожидаемый результат этапа Таблица с метриками, подтверждающая преимущество hybrid.
Этап 5: Анализ и документирование (1-2 часа)
Действия
-
Построить графики
- Влияние отношения числа Attention/Mamba слоёв на perplexity (эксперимент с
hybrid_every2,3,4). - Скорость инференса (tokens/sec) при длине последовательности 2048.
- Влияние отношения числа Attention/Mamba слоёв на perplexity (эксперимент с
-
Записать выводы в README
- Какая конфигурация дала лучший perplexity?
- Сколько памяти занимает каждая модель?
- Устойчивость к длинным контекстам (Mamba даёт O(1) память на шаг, Attention O(L^2)).
-
Создать финальный отчёт
- Markdown-файл
hybrid_report.mdс метриками, графиками и рекомендациями.
- Markdown-файл
Ожидаемый результат этапа Отчёт в experiments/hybrid_report.md, чекпоинты лучшей модели в best_model.pt.
5. Критерии приемки (Definition of Done)
- Код гибридной модели реализован, проходит unit-тест на синтетическом батче (loss < 10 без NaN).
- Baseline (чистый трансформер) обучен на тех же данных и гиперпараметрах.
- Гибридная модель достигла perplexity ниже baseline на валидации WikiText-103 (разница >1.0 после 5 эпох).
- На downstream задаче (CLUTRR) hybrid показал EM >= baseline (p-value < 0.05).
- Все эксперименты залогированы в wandb (или TensorBoard) с воспроизводимыми конфигами.
- Собрана статистика скорости инференса: hybrid быстрее baseline при длине > 1024 токенов (или сопоставим).
- Результаты задокументированы в
hybrid_report.mdс таблицей и графиками. - Конфигурации hyperparameters сохранены в
configs/(Hydra).
6. Ожидаемый результат
Основной артефакт Папка experiments/mamba_attention_hybrid/ с содержимым:
| Файл/Папка | Описание |
|---|---|
model.py | Класс HybridModel и вспомогательные блоки (Mamba, Attention) |
train.py | Скрипт обучения с аргументами (Hydra) |
configs/ | Конфигурации для разных вариантов (hybrid_every=2,3,4) |
best_model.pt | Чекпоинт лучшей гибридной модели |
baseline_model.pt | Чекпоинт baseline |
hybrid_report.md | Отчёт с метриками, графиками, выводами |
logs/ | Логи wandb / tensorboard |
metric_table.csv | Perplexity и EM для всех конфигураций |
Опциональные результаты Демо-ноутбук с инференсом (загрузка модели, генерация текста).
7. Возможные сложности и их решение
| Сложность | Решение |
|---|---|
| Mamba-ssm не устанавливается (отсутствие CUDA) | Реализовать упрощённый Mamba с torch.einsum и cumsum – потеря скорости, но функциональность сохранена. |
| Нестабильное обучение (loss NaN) | Уменьшить learning rate (1e-4), добавить gradient clipping (max_norm=1.0), проверить нормализацию (LayerNorm после каждого блока). |
| Гибридная модель больше памяти | Уменьшить hidden_size до 512 или expand=1; использовать gradient checkpointing (torch.utils.checkpoint). |
| Отсутствие downstream набора CLUTRR | Заменить на любой QA (например, BoolQ) или синтетическую задачу (reverse строки). |
| Недостаточно времени на CPU обучение | Ограничить тренировочные данные 1% датасета, 1 эпоху, оценить только динамику loss. |
| Сравнение с pretrained GPT-2 нечестно | Обучать baseline с нуля на тех же данных (random init). |
8. Бюджет времени (оценка)
| Этап | Время (часы) |
|---|---|
| Этап 1: Подготовка окружения и данных | 2-3 |
| Этап 2: Реализация гибридной архитектуры | 4-6 |
| Этап 3: Обучение и логирование | 6-8 (GPU) / 20+ (CPU) |
| Этап 4: Оценка на downstream | 2-3 |
| Этап 5: Анализ и документирование | 1-2 |
| Итого | 15-22 (GPU) / 30+ (CPU) |
Примечание для первого раза Добавьте резерв 50% времени на наладку Mamba и отладку стабильности.
9. Связанные вопросы из базы знаний
| Вопрос | Тема |
|---|---|
| 42 | Архитектура трансформера (Attention, positional encoding) |
| 78 | State Space Models: теория (S4, Mamba) |
| 120 | Selective scan и параллельные префиксные суммы |
| 145 | Сравнение Transformer vs SSM по сложности |
| 200 | Оценка языковых моделей (perplexity, BPC) |
| 305 | Long-context transformers (Longformer, BigBird) |
| 410 | Гибридные архитектуры (MambaFormer, Jamba) |
| 520 | Настройка гиперпараметров для LM (learning rate schedule) |
| 610 | Downstream evaluation (GLUE, SuperGLUE) |
| 755 | Gradient checkpointing и снижение памяти |
10. Чек-лист самопроверки
- Я реализовал гибридную модель, которая комбинирует Mamba и Attention, и она корректно работает на синтетическом тесте.
- Я обучил baseline (чистый трансформер) с нуля с теми же гиперпараметрами.
- Я сравнил perplexity на валидации и убедился, что hybrid лучше (разница >1).
- Я провёл downstream оценку и получил метрики выше или равные baseline.
- Я записал конфигурацию эксперимента, лосс и параметры в wandb / отчёт.
- Я проверил, что скорость инференса не сильно упала (или улучшилась на длинных последовательностях).
- Я подготовил README с выводами и воспроизводимыми командами.