English translation is not available yet. Showing Russian content.

Настроить гибридную архитектуру Mamba + Attention для улучшения качества языковой модели

ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Настроить гибридную архитектуру Mamba + Attention для улучшения качества языковой модели

1. Цель задачи

Создать и обучить гибридную архитектуру, которая использует механизм самовнимания (Attention) для моделирования глобальных зависимостей (дальние токены) и Mamba (State Space Model) для эффективного локального контекста. Требуется сравнить качество гибридной модели с чистым трансформером одинаковой вычислительной ёмкости на задаче языкового моделирования (perplexity) и, опционально, на downstream задачах (например, логическое рассуждение на CLUTRR или QA).

Ключевой результат Гибридная модель (Mamba + Attention) демонстрирует лучший perplexity на тестовом наборе (WikiText-103) по сравнению с чистым трансформером при равном или меньшем числе параметров и / или времени инференса; метрики качества (exact match / F1) на задаче синтеза длинных текстов не ниже baseline.


2. Исходные данные

Что нужноОткуда взять
Фреймворк глубокого обученияPyTorch 2.x (установить через pip)
Реализация слоя MambaРепозиторий mamba-ssm (pip install mamba-ssm) или самописная реализация (см. раздел симуляции)
Базовая модель TransformerHuggingFace transformersGPT-2 small (124M) как baseline
Датасет для языкового моделированияWikiText-103 (HuggingFace datasets) – train/validation/test
Датасет для downstream оценки (опционально)CLUTRR (логические цепочки) или LongBench (обобщение длинных текстов)
Средство логированияWeights & Biases (wandb) или MLflow – завести проект mamba_attention_hybrid
Хранилище результатовЛокальная папка ./experiments/ с чекпоинтами и метриками

Если нет реального инструмента — симулируем:

  1. Если mamba-ssm не устанавливается (например, из-за CUDA): реализуем упрощённый Mamba-блок на чистом PyTorch.

    • Использовать готовый код из репозитория mamba (файл mamba_ssm/modules/mamba_simple.py) – скопировать как mamba_block.py.
    • Заменить cuda-зависимые операции (например, selective scan) на torch.einsum + сканирование через torch.cumsum (потеря скорости, но сохранение функциональности).
    • Проверить корректность на случайном тензоре длины 512.
  2. Если нет GPU: все запуски проводить на CPU с reduced batch size (1-2) и ограниченной длиной (128 токенов), обучение 1-2 эпохи только для проверки сходимости.


3. Технологический стек

КомпонентИнструментыНазначение
ФреймворкPyTorch 2.2Обучение и инференс
Mamba-реализацияmamba-ssm (v1.2+) или самописная (sim_mamba)SSP-слой для локального контекста
Трансформер-блокиHuggingFace Transformers (GPT-2Block)Attention-слои для глобального контекста
Датасетыdatasets (WikiText-103)Тренировка и валидация
ТокенизаторGPT-2 tokenizer (HuggingFace)Токенизация текста
ЛогированиеWeights & Biases (wandb)Отслеживание метрик, лосса, perplexity
ГиперпараметрыHydra / argparseУправление конфигурацией
СредаLinux (Ubuntu 22.04), Python 3.10Выполнение экспериментов
МетрикиPerplexity (скользящая), loss, speed (tokens/sec)Качество и эффективность

4. Этапы выполнения

Этап 1: Подготовка окружения и данных (2-3 часа)

Действия

  1. Создать виртуальное окружение и установить зависимости:

    python -m venv venv
    source venv/bin/activate
    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    pip install transformers datasets wandb mamba-ssm hydra-core
    

    Если mamba-ssm не ставится – установить самописную реализацию из mamba_block.py (без cuda-оптимизаций).

  2. Загрузить датасет и токенизатор

    from datasets import load_dataset
    from transformers import GPT2Tokenizer
    dataset = load_dataset("wikitext", "wikitext-103-raw-v1")
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    
  3. Подготовить DataLoader

    • Объединить все тексты в один длинный поток (concatenate с \n).
    • Разделить на чанки по block_size=1024 токенов.
    • Создать Dataset с перекрытием (stride=512) для обучения.
    • Базовая конфигурация: batch_size=4 (GPU) / 1 (CPU), num_workers=2.
  4. Записать baseline (чистый трансформер):

    • Взять GPT-2 small из HuggingFace: model = GPT2LMHeadModel.from_pretrained("gpt2").
    • Обучить 1 эпоху на WikiText-103 (или зафиксировать pretrained perplexity на валидации).

Ожидаемый результат этапа Готовый train_loader, val_loader, baseline модель (GPT-2) с измеренным perplexity на валидации (~30-35 для pretrained). Результат записан в baseline_perplexity.json.

Этап 2: Проектирование и реализация гибридной архитектуры (4-6 часов)

Действия

  1. Спроектировать конфигурацию (Hydra YAML):

    # config.yaml
    model:
      vocab_size: 50257
      hidden_size: 768
      num_layers: 12
      hybrid_every: 3  # каждый 3-й слой – Attention, остальные Mamba
      attn_layers: [3, 6, 9, 12]  # список позиций (1-indexed)
    training:
      lr: 3e-4
      epochs: 5
      block_size: 1024
    
  2. Написать класс HybridModel

  3. Реализовать самописный Mamba (если нет mamba-ssm):

    class SimpleMamba(nn.Module):
        def __init__(self, d_model, d_state=16, expand=2):
            super().__init__()
            # Параметры: A_log, D, dt_proj, in_proj, x_proj, out_proj
            self.in_proj = nn.Linear(d_model, expand * d_model)
            self.ssm = SSM(d_model * expand, d_state)  # самописный selective scan
            self.out_proj = nn.Linear(expand * d_model, d_model)
    
        def forward(self, x):
            # x: (B, L, D)
            residual = x
            x = self.in_proj(x)
            x = self.ssm(x)
            x = self.out_proj(x)
            return x + residual
    

    (Подробную реализацию selective scan см. в attached ssm_scan.py – предоставить в репозитории).

  4. Проверить суммарное число параметров

    • Гибридная модель должна иметь не больше параметров, чем GPT-2 small (124M). Если больше – уменьшить hidden_size или expand.

Ожидаемый результат этапа Класс HybridModel (файл model.py), проверка на синтетическом батче – loss не NaN, forward проходит.

Этап 3: Обучение и логирование (6-8 часов на GPU / 20+ часов на CPU)

Действия

  1. Интегрировать с wandb

    • Инициализация: wandb.init(project="mamba_attention_hybrid", name="hybrid_v1").
    • Логировать: loss (train/val), perplexity, learning rate, tokens_per_sec.
  2. Написать training loop

    • Использовать AdamW с lr=3e-4, weight_decay=0.1.
    • Планировщик: cosine decay, warmup steps 500.
    • Сохранять чекпоинты каждые 2500 шагов.
    • Валидация каждые 500 шагов на 200 батчах.
  3. Параллельно обучить baseline (если не был предобучен):

    • Та же процедура, те же гиперпараметры, модель – GPT2LMHeadModel (не предобученная, случайная инициализация).
  4. Сравнивать динамику

    • Оба эксперимента фиксировать в wandb с тегом baseline / hybrid.
    • Замерить perplexity на валидации после 1, 3, 5 эпох.

Ожидаемый результат этапа Два набора чекпоинтов (baseline, hybrid) и wandb дашборд с графиками loss и perplexity. Гибридная модель должна показывать меньший perplexity после 3-5 эпох (разница >1 пункт).

Этап 4: Оценка на downstream задаче (2-3 часа)

Действия

  1. Выбрать задачу

    • CLUTRR (логические цепочки) – нужен инференс с генерированными ответами.
    • Использовать нулевой выстрел (zero-shot) с готовым промптом: "Given the story: {story}\nQ: {question}\nA:".
  2. Сгенерировать ответы для обеих моделей (baseline и hybrid) с temperature=0.7, max_new_tokens=50.

  3. Оценить exact match (EM) и F1 на тестовой части CLUTRR (200 примеров).

  4. Записать метрики и сравнить

    • Ожидание: hybrid модель имеет EM выше на 2-5% из-за лучшего моделирования локальных связей (Mamba помогает с цепочками шагов).

Ожидаемый результат этапа Таблица с метриками, подтверждающая преимущество hybrid.

Этап 5: Анализ и документирование (1-2 часа)

Действия

  1. Построить графики

    • Влияние отношения числа Attention/Mamba слоёв на perplexity (эксперимент с hybrid_every 2,3,4).
    • Скорость инференса (tokens/sec) при длине последовательности 2048.
  2. Записать выводы в README

    • Какая конфигурация дала лучший perplexity?
    • Сколько памяти занимает каждая модель?
    • Устойчивость к длинным контекстам (Mamba даёт O(1) память на шаг, Attention O(L^2)).
  3. Создать финальный отчёт

    • Markdown-файл hybrid_report.md с метриками, графиками и рекомендациями.

Ожидаемый результат этапа Отчёт в experiments/hybrid_report.md, чекпоинты лучшей модели в best_model.pt.


5. Критерии приемки (Definition of Done)

  • Код гибридной модели реализован, проходит unit-тест на синтетическом батче (loss < 10 без NaN).
  • Baseline (чистый трансформер) обучен на тех же данных и гиперпараметрах.
  • Гибридная модель достигла perplexity ниже baseline на валидации WikiText-103 (разница >1.0 после 5 эпох).
  • На downstream задаче (CLUTRR) hybrid показал EM >= baseline (p-value < 0.05).
  • Все эксперименты залогированы в wandb (или TensorBoard) с воспроизводимыми конфигами.
  • Собрана статистика скорости инференса: hybrid быстрее baseline при длине > 1024 токенов (или сопоставим).
  • Результаты задокументированы в hybrid_report.md с таблицей и графиками.
  • Конфигурации hyperparameters сохранены в configs/ (Hydra).

6. Ожидаемый результат

Основной артефакт Папка experiments/mamba_attention_hybrid/ с содержимым:

Файл/ПапкаОписание
model.pyКласс HybridModel и вспомогательные блоки (Mamba, Attention)
train.pyСкрипт обучения с аргументами (Hydra)
configs/Конфигурации для разных вариантов (hybrid_every=2,3,4)
best_model.ptЧекпоинт лучшей гибридной модели
baseline_model.ptЧекпоинт baseline
hybrid_report.mdОтчёт с метриками, графиками, выводами
logs/Логи wandb / tensorboard
metric_table.csvPerplexity и EM для всех конфигураций

Опциональные результаты Демо-ноутбук с инференсом (загрузка модели, генерация текста).


7. Возможные сложности и их решение

СложностьРешение
Mamba-ssm не устанавливается (отсутствие CUDA)Реализовать упрощённый Mamba с torch.einsum и cumsum – потеря скорости, но функциональность сохранена.
Нестабильное обучение (loss NaN)Уменьшить learning rate (1e-4), добавить gradient clipping (max_norm=1.0), проверить нормализацию (LayerNorm после каждого блока).
Гибридная модель больше памятиУменьшить hidden_size до 512 или expand=1; использовать gradient checkpointing (torch.utils.checkpoint).
Отсутствие downstream набора CLUTRRЗаменить на любой QA (например, BoolQ) или синтетическую задачу (reverse строки).
Недостаточно времени на CPU обучениеОграничить тренировочные данные 1% датасета, 1 эпоху, оценить только динамику loss.
Сравнение с pretrained GPT-2 нечестноОбучать baseline с нуля на тех же данных (random init).

8. Бюджет времени (оценка)

ЭтапВремя (часы)
Этап 1: Подготовка окружения и данных2-3
Этап 2: Реализация гибридной архитектуры4-6
Этап 3: Обучение и логирование6-8 (GPU) / 20+ (CPU)
Этап 4: Оценка на downstream2-3
Этап 5: Анализ и документирование1-2
Итого15-22 (GPU) / 30+ (CPU)

Примечание для первого раза Добавьте резерв 50% времени на наладку Mamba и отладку стабильности.


9. Связанные вопросы из базы знаний

ВопросТема
42Архитектура трансформера (Attention, positional encoding)
78State Space Models: теория (S4, Mamba)
120Selective scan и параллельные префиксные суммы
145Сравнение Transformer vs SSM по сложности
200Оценка языковых моделей (perplexity, BPC)
305Long-context transformers (Longformer, BigBird)
410Гибридные архитектуры (MambaFormer, Jamba)
520Настройка гиперпараметров для LM (learning rate schedule)
610Downstream evaluation (GLUE, SuperGLUE)
755Gradient checkpointing и снижение памяти

10. Чек-лист самопроверки

  • Я реализовал гибридную модель, которая комбинирует Mamba и Attention, и она корректно работает на синтетическом тесте.
  • Я обучил baseline (чистый трансформер) с нуля с теми же гиперпараметрами.
  • Я сравнил perplexity на валидации и убедился, что hybrid лучше (разница >1).
  • Я провёл downstream оценку и получил метрики выше или равные baseline.
  • Я записал конфигурацию эксперимента, лосс и параметры в wandb / отчёт.
  • Я проверил, что скорость инференса не сильно упала (или улучшилась на длинных последовательностях).
  • Я подготовил README с выводами и воспроизводимыми командами.