中文翻译暂不可用,显示俄语原文。

Развернуть Mamba-2 локально и сравнить perplexity с Llama-3-8B на длинном контексте

ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Развернуть Mamba-2 локально и сравнить perplexity с Llama-3-8B на длинном контексте

1. Цель задачи

Практическое знакомство с архитектурой State Space Models (SSM) на примере Mamba-2. Необходимо развернуть предобученную модель Mamba-2 локально, измерить её perplexity на последовательностях длиной от 8K до 32K токенов и сравнить с эталонной Transformer-моделью Llama-3-8B. Дополнительно фиксируется скорость инференса и потребление памяти.

Ключевой результат На длинных контекстах Mamba-2 показывает лучшее время инференса (минимум в 2 раза быстрее) при perplexity, не превышающем Llama-3-8B более чем на 5%.

2. Исходные данные

Что нужноОткуда взять
Модель Mamba-2 (предобученная, 2.8B параметров)Репозиторий state-spaces/mamba на GitHub
Модель Llama-3-8B (предобученная)HuggingFace: meta-llama/Meta-Llama-3-8B
Токенизатор для Llama-3-8BHuggingFace вместе с моделью
Токенизатор для Mamba-2EleutherAI/gpt-neox-20b (используется в Mamba)
Тестовый датасет длинных контекстовPG-19 (Project Gutenberg), выборка документов >32K токенов
Среда с GPU (минимум 16GB VRAM)Локальная / облачная (Colab Pro, RunPod, Lambda)

Если нет реального GPU с 24GB+ VRAM — симулируем:

  1. Используем меньшие версии моделей (Mamba-790M и Llama-3-8B в 4-битной квантованной версии через bitsandbytes).
  2. Ограничиваем длину контекста до 4K токенов.
  3. Замеряем время на CPU в режиме debug (но тогда скорость не будет показательной).

3. Технологический стек

КомпонентИнструментыНазначение
Язык программированияPython 3.10+Реализация скриптов
Фреймворк глубокого обученияPyTorch 2.1+Вычисления на GPU
Mamba-2Репозиторий mamba (pip install mamba-ssm)Загрузка и инференс Mamba-2
TransformersHuggingFace transformers 4.38+Загрузка Llama-3-8B, токенизация
Квантование (опционально)bitsandbytes4-битная загрузка Llama для экономии VRAM
Бенчмаркингtime, torch.cuda.Event, psutilЗамер времени и памяти
Датасетdatasets (HuggingFace)PG-19
Логированиеwandb или простой JSONСохранение метрик

4. Этапы выполнения

Этап 1: Настройка окружения и установка Mamba-2 (30 минут)

Действия

  1. Создать виртуальное окружение Python 3.10, активировать.
  2. Установить PyTorch с поддержкой CUDA (версия, совместимая с драйвером).
  3. Установить Mamba-2 из официального репозитория:
    pip install mamba-ssm[causal-conv1d]
    
    Если не собирается causal-conv1d, установить без него: pip install mamba-ssm
  4. Установить библиотеки: transformers, datasets, accelerate, bitsandbytes, einops.
  5. Клонировать репозиторий Mamba для тестовых скриптов (опционально):
    git clone https://github.com/state-spaces/mamba
    
  6. Проверить доступность GPU: python -c "import torch; print(torch.cuda.is_available(), torch.cuda.get_device_name(0))"

Ожидаемый результат этапа Установлены все зависимости, GPU определяется, Mamba-2 импортируется без ошибок.

Этап 2: Загрузка моделей и токенизаторов (30 минут)

Действия

  1. Загрузить Mamba-2 (например, версию 2.8B):
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
    import torch
    model_mamba = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device="cuda", dtype=torch.float16)
    
  2. Загрузить Llama-3-8B через transformers (с квантованием, если VRAM < 24GB):
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model_llama = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B",
        device_map="auto",
        load_in_4bit=True,  # если памяти мало
        torch_dtype=torch.float16
    )
    
  3. Загрузить токенизаторы:
    tokenizer_mamba = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    
  4. Убедиться, что обе модели корректно работают на коротком тексте.

Ожидаемый результат этапа Обе модели загружены, способны генерировать текст (хотя бы один токен), потребление GPU не превышает доступный объём.

Этап 3: Подготовка тестового датасета длинных контекстов (30 минут)

Действия

  1. Загрузить датасет PG-19:
    from datasets import load_dataset
    ds = load_dataset("pg19", split="test")
    
  2. Выбрать 10 документов длиной не менее 35K токенов (после токенизации Llama-токенизатором).
    • Написать функцию фильтрации, которая токенизирует первые 100 символов, оценивает длину.
    • Для каждого документа обрезать до 32K токенов (первые 32K).
  3. Создать три набора длин: 8K, 16K, 32K токенов (можно обрезать те же документы).
  4. Сохранить токенизированные последовательности в формате .pt (torch tensors) для ускорения повторных замеров.

Ожидаемый результат этапа Подготовлено 10 последовательностей для каждой длины (всего 30 тензоров).

Этап 4: Инференс и сбор метрик (2 часа)

Действия

  1. Написать функцию compute_perplexity_and_speed(model, tokenizer, input_ids, model_name):
    • Засечь время с помощью torch.cuda.Event.
    • Прогнать модель в режиме torch.no_grad(), посчитать кросс-энтропийный loss (без редукции).
    • Рассчитать perplexity = exp(loss).
    • Зафиксировать пиковое потребление памяти GPU (через torch.cuda.max_memory_allocated()).
  2. Для каждой длины (8K, 16K, 32K) и для каждой модели выполнить:
    • Прогнать все 10 последовательностей.
    • Усреднить метрики по запускам.
    • Убедиться, что последовательности укладываются в VRAM (для Llama-8B на 32K может потребоваться gradient checkpointing или offloading).
  3. Сохранить результаты в CSV или JSON:
    {
      "model": "mamba-2.8b",
      "context_length": 8192,
      "perplexity": 12.34,
      "time_seconds": 2.1,
      "max_memory_mb": 8500
    }
    

Ожидаемый результат этапа Собраны численные данные по perplexity, времени и памяти для обеих моделей на трёх длинах контекста.

Этап 5: Анализ и визуализация результатов (30 минут)

Действия

  1. Построить таблицу сравнения (пример):
МодельДлина контекстаPerplexityВремя (сек)Память (MB)
Mamba-2.8B8K14.21.54500
Llama-3-8B8K13.82.88000
...............
  1. Вычислить относительное ускорение: time_llama / time_mamba.
  2. Вычислить относительную разницу в perplexity: (ppl_mamba - ppl_llama) / ppl_llama * 100%.
  3. Если met условие "Mamba быстрее в 2+ раза и perplexity хуже не более чем на 5%" — задача выполнена.
  4. Визуализировать (опционально): bar chart perplexity vs длина, и времени vs длина.

Ожидаемый результат этапа График/таблица, демонстрирующая выполнение ключевого результата.

5. Критерии приемки (Definition of Done)

  • Mamba-2 загружена и выполняет инференс локально (CPU/GPU) без ошибок.
  • Llama-3-8B загружена (возможно, в 4-битном формате) и также выполняет инференс.
  • Подготовлен тестовый датасет из 10 документов длиной 8K, 16K, 32K токенов.
  • Измерены perplexity, время инференса и пиковая память для обеих моделей на всех длинах.
  • Mamba-2 показала время инференса как минимум в 2 раза меньше, чем Llama-3-8B на каждом контексте.
  • Perplexity Mamba-2 не превышает perplexity Llama-3-8B более чем на 5% (относительно) на всех длинах.
  • Результаты сохранены в JSON и представлены в виде таблицы с выводом выполнения условия.
  • Код воспроизводим: есть скрипт run_benchmark.py, который можно запустить заново.

6. Ожидаемый результат

Основной артефакт Папка с файлами:

  • run_benchmark.py — главный скрипт бенчмарка.
  • benchmark_results.json — собранные метрики.
  • comparison_table.md — отформатированная таблица с выводом.

Содержание benchmark_results.json

[
  {
    "model": "mamba-2.8b",
    "context_length": 8192,
    "perplexity": 14.23,
    "time_s": 1.52,
    "memory_mb": 4520
  },
  ...
]

Опциональные доп. результаты

  • Графики в PNG (perplexity vs context length, time vs context length).
  • Сравнение потребления памяти (stacked bar).
  • Заметки о проблемах при работе с длинными контекстами у Transformer (OOM).

7. Возможные сложности и их решение

СложностьРешение
Mamba-2 не компилируется (causal-conv1d)Установить pip install mamba-ssm без опции [causal-conv1d] — модель будет работать, но медленнее.
Llama-3-8B не влезает в VRAM (16GB) для 32K контекстаИспользовать load_in_4bit=True (quantization) и device_map="auto".
Результаты perplexity сильно различаются между запускамиФиксировать seed, использовать torch.inference_mode(), усреднять по 10 документам.
Время инференса измеряется неточноИспользовать torch.cuda.synchronize() до и после, исключить прогрев (первые запуски отбрасывать).
Нет подходящего датасетаВзять первые 10 глав из любой длинной книги в открытом доступе (например, "Война и мир" в plain text), токенизировать и обрезать.

8. Бюджет времени (оценка)

ЭтапВремя
Этап 1: Настройка окружения30 мин
Этап 2: Загрузка моделей30 мин
Этап 3: Подготовка датасета30 мин
Этап 4: Инференс и сбор метрик2 ч
Этап 5: Анализ и визуализация30 мин
Итого4 ч

Примечание Для первого раза может потребоваться до 6 часов из-за непредвиденных проблем с зависимостями или нехваткой памяти.

9. Связанные вопросы из базы знаний

ВопросТема
15Что такое State Space Model (SSM)?
18Как устроено внимание в Transformer?
24Сравнение скорости transformer vs SSM на длинных последовательностях
32Что такое perplexity и как её считать?
45Quantization моделей (bitsandbytes)
78Инференс больших моделей с ограниченным VRAM
103Токенизаторы: GPT-NeoX vs Llama
201Архитектура Mamba-2 (selective state space)
304PG-19 датасет для оценки долгосрочных зависимостей
512Gradient checkpointing для длинного контекста

10. Чек-лист самопроверки

  • Я точно следовал этапам и не пропустил ни один.
  • Обе модели загружены корректно и выдают осмысленные логиты на проверочном тексте.
  • Для каждой длины контекста я прогнал минимум 10 разных документов и усреднил метрики.
  • Я зафиксировал все результаты в JSON и убедился, что файл читается.
  • Я проверил выполнение ключевого условия: Mamba быстрее в 2+ раза, а perplexity в пределах 5% от Llama.