中文翻译暂不可用,显示俄语原文。
Реализовать selective scan (Mamba)
ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Реализовать selective scan (Mamba)
1. Цель задачи
Разработать CUDA-ядро для selective scan — ключевого оператора архитектуры Mamba (State Space Model). Реализация должна поддерживать эффективный последовательный проход по скрытому состоянию с селективным взвешиванием входных данных. Научиться запускать кастомный CUDA kernel через PyTorch, верифицировать корректность на синтетических данных и сравнить с эталонной реализацией (наивная Python-модель).
Ключевой результат Рабочий CUDA kernel для selective scan, интегрированный в PyTorch как пользовательская функция mamba_scan, корректность которой подтверждена численным тестом (разница <1e-5).
2. Исходные данные
| Что нужно | Откуда взять |
|---|---|
| Бумага Mamba (SSM) | https://arxiv.org/abs/2312.00752 |
| Исходный код Mamba (PyTorch reference) | https://github.com/state-spaces/[mamba](/wiki/Mamba) |
| CUDA-совместимый NVIDIA GPU (или эмуляция) | Локальная машина / Colab / облачный GPU |
| PyTorch 2.x | pip install torch |
| Компилятор CUDA (nvcc) | Установка CUDA Toolkit (рекомендуется 12.x) |
| Jupyter Notebook / Python script | Любая среда |
Если нет GPU — симулируем
- Установите PyTorch с CPU-версией (не требует CUDA).
- Напишите Python-функцию
selective_scan_cpuс циклами, реализующую ту же логику. - Код kernel (CUDA C++) можно отладить с помощью компилятора nvcc в режиме
--device-emu(эмуляция CUDA на CPU, устаревшая) или использовать Google Colab (бесплатный GPU T4). - Для проверки синтаксиса и псевдокода можно использовать CUDA Online Compiler](https://cuda.godbolt.org/).
3. Технологический стек
| Компонент | Инструменты | Назначение |
|---|---|---|
| Язык программирования | C++ (CUDA), Python | Реализация kernel и обёртки |
| Фреймворк | PyTorch 2.x | Forward/backward, интеграция |
| Компилятор | nvcc (CUDA Toolkit) | Компиляция .cu в .ptx / cubin |
| Биндинги | torch.utils.cpp_extension | Подключение кастомного CUDA кода |
| Тестирование | pytest, NumPy | Верификация результатов |
| Профилирование (опционально) | nvprof / Nsight Compute | Оптимизация производительности |
4. Этапы выполнения
Этап 1: Изучение теории и эталонной реализации (1–2 часа)
Действия
-
Прочитать разделы 1–3 статьи Mamba (до Selective Scan).
-
Ознакомиться с официальной реализацией
mamba_ssmна GitHub, найти функциюselective_scan_fn. -
Выписать математическое описание selective scan:
state_t = A_bar_t * state_{t-1} + B_bar_t * x_t y_t = C_t * state_t где A_bar_t = exp(delta_t * A) B_bar_t = delta_t * B delta_t = softplus(Delta_t) # learnable step -
Реализовать Python-прототип в
selective_scan_ref.py:import torch import torch.nn.functional as F def selective_scan_ref(u, delta, A, B, C, D, z=None): batch, dim, seqlen = u.shape delta = F.softplus(delta) # discrete A and B deltaA = torch.exp(delta.unsqueeze(-1) * A) deltaB = delta.unsqueeze(-1) * B # scan state = torch.zeros(batch, dim, device=u.device) y = torch.empty(batch, dim, seqlen, device=u.device) for t in range(seqlen): state = deltaA[..., t] * state + deltaB[..., t] * u[..., t] y[..., t] = (C[..., t] * state).sum(-1) + D * u[..., t] if z is not None: y = y * F.silu(z) return y -
Протестировать прототип на случайном тензоре.
Ожидаемый результат этапа Рабочая selective_scan_ref на Python; понимание всех математических операций.
Этап 2: Разработка CUDA-ядра (3–5 часов)
Действия
-
Создать файл
selective_scan_kernel.cu. -
Реализовать три версии kernel для сравнения:
- наивная: однопоточная (1 thread per sequence element) —
naive_scan_kernel - параллельная с warp-level: параллельный префикс-сумма в warp (не подходит, т.к. процесс рекуррентный) — покажите, что здесь нужен именно последовательный scan.
- оптимизированная с использованием shared memory (построчная обработка с блоками).
- наивная: однопоточная (1 thread per sequence element) —
-
Базовая сигнатура kernel:
__global__ void selective_scan_kernel( const float* u, // [batch, dim, seqlen] const float* delta, const float* A, const float* B, const float* C, float* y, int batch, int dim, int seqlen ) { // Каждый thread обрабатывает один элемент (b,d) на всём seqlen int idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx >= batch * dim) return; int b = idx / dim; int d = idx % dim; float state = 0.0f; for (int t = 0; t < seqlen; t++) { float u_t = u[(b*dim + d)*seqlen + t]; float delta_t = delta[(b*dim + d)*seqlen + t]; float A_val = A[d]; // один параметр на размерность float B_val = B[(b*dim + d)*seqlen + t]; // [b,d,t] если B обучаемый float C_val = C[(b*dim + d)*seqlen + t]; float deltaA = expf(delta_t * A_val); float deltaB = delta_t * B_val; state = deltaA * state + deltaB * u_t; y[(b*dim + d)*seqlen + t] = C_val * state; } } -
Выделить память, скопировать данные и запустить kernel.
-
Написать torch.autograd.Function для связывания kernel с PyTorch:
import torch from torch.utils.cpp_extension import load_inline cuda_source = open('selective_scan_kernel.cu').read() module = load_inline('selective_scan', cuda_sources=[cuda_source], functions=['selective_scan_forward'], verbose=True) class SelectiveScanFunction(torch.autograd.Function): @staticmethod def forward(ctx, u, delta, A, B, C, D): # ... вызвать модуль.selective_scan_forward pass # backward (пропущен для упрощения, но может быть добавлен позже)
Ожидаемый результат этапа CUDA-функция selective_scan_forward, скомпилированная и доступная из Python, которая возвращает корректные выходные тензоры (проверить совпадение с эталоном на малых размерах).
Этап 3: Интеграция и тестирование (1–2 часа)
Действия
- Написать тестовый скрипт
test_selective_scan.py. - Сравнить вывод CUDA kernel с эталонной Python-реализацией для разных размеров:
- batch = 2, dim = 4, seqlen = 8
- batch = 1, dim = 32, seqlen = 256
- batch = 8, dim = 128, seqlen = 64
- Замерить абсолютную разницу:
(cuda_out - ref_out).abs().max()— должно быть <1e-5 (float). - Проверить градиенты (если backward реализован): через torch.autograd.gradcheck.
- Создать benchmark: сравнить скорость CUDA kernel vs Python-цикл.
Ожидаемый результат этапа Все тесты пройдены; выведены замеры времени.
Этап 4: Оптимизация (опциональный, 1–2 часа)
Действия
- Использовать профилировщик Nsight Compute для выявления узких мест.
- Оптимизировать доступ к памяти: использовать
coalescedчтение/запись (переупаковка тензоров в layout[seqlen, batch, dim]). - Использовать
__ldg(cached read) и__syncthreadsпри необходимости. - Рассмотреть использование
float4для векторной загрузки. - Написать второй kernel с этими оптимизациями и сравнить производительность.
Ожидаемый результат этапа Оптимизированная версия минимум в 2 раза быстрее наивной.
5. Критерии приемки (Definition of Done)
- CUDA-код компилируется без ошибок и линкуется с PyTorch.
- Функция
selective_scan_forwardкорректно обрабатывает граничные случаи: dim=1, seqlen=1, batch=0 (недопустимо, но хотя бы не краш). - Разница между выходом CUDA kernel и эталонной Python-реализацией < 1e-5 по всем элементам (float32).
- Есть автоматизированный тест
pytest. - Документированы ограничения kernel (например, максимальный seqlen из-за регистров, только float32).
- Реализован backward (опционально, но желательно для практического использования) — проверено
gradcheck. - Время выполнения CUDA kernel меньше, чем у Python-цикла в 100+ раз для seqlen=2048, dim=256, batch=1.
- Код выложен в Git-репозиторий с читаемым README.
6. Ожидаемый результат
Основной артефакт Python-модуль mamba_scan_cuda.cpython-*.so (или .ptx), который экспортирует функцию selective_scan_forward(u, delta, A, B, C, D) -> y.
Содержимое репозитория
.
├── selective_scan_kernel.cu # CUDA kernel
├── selective_scan_wrapper.py # PyTorch autograd function
├── test_selective_scan.py # тесты
├── benchmark.py # замеры производительности
├── selective_scan_ref.py # эталон на Python
└── README.md # инструкция по сборке
Дополнительно График сравнения времени выполнения vs seqlen.
7. Возможные сложности и их решение
| Сложность | Решение |
|---|---|
Ошибка компиляции CUDA: expected an identifier | Проверить макросы; убедиться, что __global__ и __device__ корректны. Использовать простой стартовый шаблон. |
| Kernel долго считает или не сходится | Использовать printf в kernel для отладки (cuda-gdb). Проверить layout памяти – тензоры могут быть в порядке C vs row-major. |
| Неправильный broadcast A (параметры размерности) | A обычно имеет форму (dim,) – нужно правильно применить exp(delta*A) с вещанием в kernel. |
| Memory access violation | Проверить off-by-one в индексах; использовать cuda-memcheck. |
| Gradient computation not implemented | Пока опустить backward; для задания достаточно forward + численной проверки. |
| Коллапс при больших seqlen (регистровый спилл) | Уменьшить количество временных переменных, использовать volatile или вынести часть в локальную память. |
8. Бюджет времени (оценка)
| Этап | Время |
|---|---|
| 1. Изучение теории и прототип | 1.5 часа |
| 2. Разработка CUDA-ядра | 4 часа |
| 3. Интеграция и тестирование | 2 часа |
| 4. Оптимизация (опционально) | 2 часа |
| Итого (основное) | 7.5 часов |
| Итого (с оптимизацией) | 9.5 часов |
Примечание Для первого раза допустимо потратить до 12 часов с учётом отладки.
9. Связанные вопросы из базы знаний
| Вопрос | Тема |
|---|---|
| 1 | Basic CUDA programming model |
| 47 | PyTorch custom CUDA extensions |
| 102 | State Space Models overview |
| 203 | Attention vs SSM tradeoffs |
| 310 | Writing fused kernels for recurrent operations |
| 415 | Profiling CUDA kernels (Nsight) |
| 520 | Mixed precision training with custom kernels |
| 627 | Numerical stability of exp() in long sequences |
| 734 | Parallel scan algorithms (prefix sum) |
| 845 | Gradient computation for linear recurrence |
10. Чек-лист самопроверки
- Я понимаю математику selective scan и могу объяснить её на доске за 5 минут.
- Моя CUDA-функция возвращает правильные значения на синтетическом тесте (совпадение с Python-эталоном с точностью 1e-5).
- Я написал хотя бы один автоматический тест на pytest.
- Я замерил скорость на seqlen=1024 и сравнил с наивной реализацией (разница > 100x).
- В README указаны инструкции по сборке и запуску тестов.
- Код оформлен согласно PEP8 (python) и CUDA C++ best practices.