中文翻译暂不可用,显示俄语原文。
Развернуть Mamba-2 локально и сравнить perplexity с Llama-3-8B на длинном контексте
ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Развернуть Mamba-2 локально и сравнить perplexity с Llama-3-8B на длинном контексте
1. Цель задачи
Практическое знакомство с архитектурой State Space Models (SSM) на примере Mamba-2. Необходимо развернуть предобученную модель Mamba-2 локально, измерить её perplexity на последовательностях длиной от 8K до 32K токенов и сравнить с эталонной Transformer-моделью Llama-3-8B. Дополнительно фиксируется скорость инференса и потребление памяти.
Ключевой результат На длинных контекстах Mamba-2 показывает лучшее время инференса (минимум в 2 раза быстрее) при perplexity, не превышающем Llama-3-8B более чем на 5%.
2. Исходные данные
| Что нужно | Откуда взять |
|---|---|
| Модель Mamba-2 (предобученная, 2.8B параметров) | Репозиторий state-spaces/mamba на GitHub |
| Модель Llama-3-8B (предобученная) | HuggingFace: meta-llama/Meta-Llama-3-8B |
| Токенизатор для Llama-3-8B | HuggingFace вместе с моделью |
| Токенизатор для Mamba-2 | EleutherAI/gpt-neox-20b (используется в Mamba) |
| Тестовый датасет длинных контекстов | PG-19 (Project Gutenberg), выборка документов >32K токенов |
| Среда с GPU (минимум 16GB VRAM) | Локальная / облачная (Colab Pro, RunPod, Lambda) |
Если нет реального GPU с 24GB+ VRAM — симулируем:
- Используем меньшие версии моделей (Mamba-790M и Llama-3-8B в 4-битной квантованной версии через bitsandbytes).
- Ограничиваем длину контекста до 4K токенов.
- Замеряем время на CPU в режиме debug (но тогда скорость не будет показательной).
3. Технологический стек
| Компонент | Инструменты | Назначение |
|---|---|---|
| Язык программирования | Python 3.10+ | Реализация скриптов |
| Фреймворк глубокого обучения | PyTorch 2.1+ | Вычисления на GPU |
| Mamba-2 | Репозиторий mamba (pip install mamba-ssm) | Загрузка и инференс Mamba-2 |
| Transformers | HuggingFace transformers 4.38+ | Загрузка Llama-3-8B, токенизация |
| Квантование (опционально) | bitsandbytes | 4-битная загрузка Llama для экономии VRAM |
| Бенчмаркинг | time, torch.cuda.Event, psutil | Замер времени и памяти |
| Датасет | datasets (HuggingFace) | PG-19 |
| Логирование | wandb или простой JSON | Сохранение метрик |
4. Этапы выполнения
Этап 1: Настройка окружения и установка Mamba-2 (30 минут)
Действия
- Создать виртуальное окружение Python 3.10, активировать.
- Установить PyTorch с поддержкой CUDA (версия, совместимая с драйвером).
- Установить Mamba-2 из официального репозитория:
Если не собирается causal-conv1d, установить без него: pip install mamba-ssmpip install mamba-ssm[causal-conv1d] - Установить библиотеки: transformers, datasets, accelerate, bitsandbytes,
einops. - Клонировать репозиторий Mamba для тестовых скриптов (опционально):
git clone https://github.com/state-spaces/mamba - Проверить доступность GPU: python -c "import torch; print(torch.cuda.is_available(), torch.cuda.get_device_name(0))"
Ожидаемый результат этапа Установлены все зависимости, GPU определяется, Mamba-2 импортируется без ошибок.
Этап 2: Загрузка моделей и токенизаторов (30 минут)
Действия
- Загрузить Mamba-2 (например, версию 2.8B):
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel import torch model_mamba = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b", device="cuda", dtype=torch.float16) - Загрузить Llama-3-8B через transformers (с квантованием, если VRAM < 24GB):
from transformers import AutoModelForCausalLM, AutoTokenizer model_llama = AutoModelForCausalLM.from_pretrained( "meta-llama/Meta-Llama-3-8B", device_map="auto", load_in_4bit=True, # если памяти мало torch_dtype=torch.float16 ) - Загрузить токенизаторы:
tokenizer_mamba = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b") tokenizer_llama = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B") - Убедиться, что обе модели корректно работают на коротком тексте.
Ожидаемый результат этапа Обе модели загружены, способны генерировать текст (хотя бы один токен), потребление GPU не превышает доступный объём.
Этап 3: Подготовка тестового датасета длинных контекстов (30 минут)
Действия
- Загрузить датасет PG-19:
from datasets import load_dataset ds = load_dataset("pg19", split="test") - Выбрать 10 документов длиной не менее 35K токенов (после токенизации Llama-токенизатором).
- Написать функцию фильтрации, которая токенизирует первые 100 символов, оценивает длину.
- Для каждого документа обрезать до 32K токенов (первые 32K).
- Создать три набора длин: 8K, 16K, 32K токенов (можно обрезать те же документы).
- Сохранить токенизированные последовательности в формате
.pt(torch tensors) для ускорения повторных замеров.
Ожидаемый результат этапа Подготовлено 10 последовательностей для каждой длины (всего 30 тензоров).
Этап 4: Инференс и сбор метрик (2 часа)
Действия
- Написать функцию compute_perplexity_and_speed(model, tokenizer, input_ids, model_name):
- Засечь время с помощью torch.cuda.Event.
- Прогнать модель в режиме torch.no_grad(), посчитать кросс-энтропийный loss (без редукции).
- Рассчитать perplexity = exp(loss).
- Зафиксировать пиковое потребление памяти GPU (через torch.cuda.max_memory_allocated()).
- Для каждой длины (8K, 16K, 32K) и для каждой модели выполнить:
- Прогнать все 10 последовательностей.
- Усреднить метрики по запускам.
- Убедиться, что последовательности укладываются в VRAM (для Llama-8B на 32K может потребоваться gradient checkpointing или offloading).
- Сохранить результаты в CSV или JSON:
{ "model": "mamba-2.8b", "context_length": 8192, "perplexity": 12.34, "time_seconds": 2.1, "max_memory_mb": 8500 }
Ожидаемый результат этапа Собраны численные данные по perplexity, времени и памяти для обеих моделей на трёх длинах контекста.
Этап 5: Анализ и визуализация результатов (30 минут)
Действия
- Построить таблицу сравнения (пример):
| Модель | Длина контекста | Perplexity | Время (сек) | Память (MB) |
|---|---|---|---|---|
| Mamba-2.8B | 8K | 14.2 | 1.5 | 4500 |
| Llama-3-8B | 8K | 13.8 | 2.8 | 8000 |
| ... | ... | ... | ... | ... |
- Вычислить относительное ускорение:
time_llama / time_mamba. - Вычислить относительную разницу в perplexity:
(ppl_mamba - ppl_llama) / ppl_llama * 100%. - Если met условие "Mamba быстрее в 2+ раза и perplexity хуже не более чем на 5%" — задача выполнена.
- Визуализировать (опционально): bar chart perplexity vs длина, и времени vs длина.
Ожидаемый результат этапа График/таблица, демонстрирующая выполнение ключевого результата.
5. Критерии приемки (Definition of Done)
- Mamba-2 загружена и выполняет инференс локально (CPU/GPU) без ошибок.
- Llama-3-8B загружена (возможно, в 4-битном формате) и также выполняет инференс.
- Подготовлен тестовый датасет из 10 документов длиной 8K, 16K, 32K токенов.
- Измерены perplexity, время инференса и пиковая память для обеих моделей на всех длинах.
- Mamba-2 показала время инференса как минимум в 2 раза меньше, чем Llama-3-8B на каждом контексте.
- Perplexity Mamba-2 не превышает perplexity Llama-3-8B более чем на 5% (относительно) на всех длинах.
- Результаты сохранены в JSON и представлены в виде таблицы с выводом выполнения условия.
- Код воспроизводим: есть скрипт
run_benchmark.py, который можно запустить заново.
6. Ожидаемый результат
Основной артефакт Папка с файлами:
run_benchmark.py— главный скрипт бенчмарка.benchmark_results.json— собранные метрики.comparison_table.md— отформатированная таблица с выводом.
Содержание benchmark_results.json
[
{
"model": "mamba-2.8b",
"context_length": 8192,
"perplexity": 14.23,
"time_s": 1.52,
"memory_mb": 4520
},
...
]
Опциональные доп. результаты
- Графики в PNG (perplexity vs context length, time vs context length).
- Сравнение потребления памяти (stacked bar).
- Заметки о проблемах при работе с длинными контекстами у Transformer (OOM).
7. Возможные сложности и их решение
| Сложность | Решение |
|---|---|
| Mamba-2 не компилируется (causal-conv1d) | Установить pip install mamba-ssm без опции [causal-conv1d] — модель будет работать, но медленнее. |
| Llama-3-8B не влезает в VRAM (16GB) для 32K контекста | Использовать load_in_4bit=True (quantization) и device_map="auto". |
| Результаты perplexity сильно различаются между запусками | Фиксировать seed, использовать torch.inference_mode(), усреднять по 10 документам. |
| Время инференса измеряется неточно | Использовать torch.cuda.synchronize() до и после, исключить прогрев (первые запуски отбрасывать). |
| Нет подходящего датасета | Взять первые 10 глав из любой длинной книги в открытом доступе (например, "Война и мир" в plain text), токенизировать и обрезать. |
8. Бюджет времени (оценка)
| Этап | Время |
|---|---|
| Этап 1: Настройка окружения | 30 мин |
| Этап 2: Загрузка моделей | 30 мин |
| Этап 3: Подготовка датасета | 30 мин |
| Этап 4: Инференс и сбор метрик | 2 ч |
| Этап 5: Анализ и визуализация | 30 мин |
| Итого | 4 ч |
Примечание Для первого раза может потребоваться до 6 часов из-за непредвиденных проблем с зависимостями или нехваткой памяти.
9. Связанные вопросы из базы знаний
| Вопрос | Тема |
|---|---|
| 15 | Что такое State Space Model (SSM)? |
| 18 | Как устроено внимание в Transformer? |
| 24 | Сравнение скорости transformer vs SSM на длинных последовательностях |
| 32 | Что такое perplexity и как её считать? |
| 45 | Quantization моделей (bitsandbytes) |
| 78 | Инференс больших моделей с ограниченным VRAM |
| 103 | Токенизаторы: GPT-NeoX vs Llama |
| 201 | Архитектура Mamba-2 (selective state space) |
| 304 | PG-19 датасет для оценки долгосрочных зависимостей |
| 512 | Gradient checkpointing для длинного контекста |
10. Чек-лист самопроверки
- Я точно следовал этапам и не пропустил ни один.
- Обе модели загружены корректно и выдают осмысленные логиты на проверочном тексте.
- Для каждой длины контекста я прогнал минимум 10 разных документов и усреднил метрики.
- Я зафиксировал все результаты в JSON и убедился, что файл читается.
- Я проверил выполнение ключевого условия: Mamba быстрее в 2+ раза, а perplexity в пределах 5% от Llama.