Реализовать selective scan (Mamba)

ТЕХНИЧЕСКОЕ ЗАДАНИЕ: Реализовать selective scan (Mamba)

1. Цель задачи

Разработать CUDA-ядро для selective scan — ключевого оператора архитектуры Mamba (State Space Model). Реализация должна поддерживать эффективный последовательный проход по скрытому состоянию с селективным взвешиванием входных данных. Научиться запускать кастомный CUDA kernel через PyTorch, верифицировать корректность на синтетических данных и сравнить с эталонной реализацией (наивная Python-модель).

Ключевой результат Рабочий CUDA kernel для selective scan, интегрированный в PyTorch как пользовательская функция mamba_scan, корректность которой подтверждена численным тестом (разница <1e-5).


2. Исходные данные

Что нужноОткуда взять
Бумага Mamba (SSM)https://arxiv.org/abs/2312.00752
Исходный код Mamba (PyTorch reference)https://github.com/state-spaces/[mamba](/wiki/Mamba)
CUDA-совместимый NVIDIA GPU (или эмуляция)Локальная машина / Colab / облачный GPU
PyTorch 2.xpip install torch
Компилятор CUDA (nvcc)Установка CUDA Toolkit (рекомендуется 12.x)
Jupyter Notebook / Python scriptЛюбая среда

Если нет GPU — симулируем

  1. Установите PyTorch с CPU-версией (не требует CUDA).
  2. Напишите Python-функцию selective_scan_cpu с циклами, реализующую ту же логику.
  3. Код kernel (CUDA C++) можно отладить с помощью компилятора nvcc в режиме --device-emu (эмуляция CUDA на CPU, устаревшая) или использовать Google Colab (бесплатный GPU T4).
  4. Для проверки синтаксиса и псевдокода можно использовать CUDA Online Compiler](https://cuda.godbolt.org/).

3. Технологический стек

КомпонентИнструментыНазначение
Язык программированияC++ (CUDA), PythonРеализация kernel и обёртки
ФреймворкPyTorch 2.xForward/backward, интеграция
Компиляторnvcc (CUDA Toolkit)Компиляция .cu в .ptx / cubin
Биндингиtorch.utils.cpp_extensionПодключение кастомного CUDA кода
Тестированиеpytest, NumPyВерификация результатов
Профилирование (опционально)nvprof / Nsight ComputeОптимизация производительности

4. Этапы выполнения

Этап 1: Изучение теории и эталонной реализации (1–2 часа)

Действия

  1. Прочитать разделы 1–3 статьи Mamba (до Selective Scan).

  2. Ознакомиться с официальной реализацией mamba_ssm на GitHub, найти функцию selective_scan_fn.

  3. Выписать математическое описание selective scan:

    state_t = A_bar_t * state_{t-1} + B_bar_t * x_t
    y_t     = C_t * state_t
    
    где A_bar_t = exp(delta_t * A)
        B_bar_t = delta_t * B
        delta_t = softplus(Delta_t)  # learnable step
    
  4. Реализовать Python-прототип в selective_scan_ref.py:

    import torch
    import torch.nn.functional as F
    
    def selective_scan_ref(u, delta, A, B, C, D, z=None):
        batch, dim, seqlen = u.shape
        delta = F.softplus(delta)
        # discrete A and B
        deltaA = torch.exp(delta.unsqueeze(-1) * A)
        deltaB = delta.unsqueeze(-1) * B
        # scan
        state = torch.zeros(batch, dim, device=u.device)
        y = torch.empty(batch, dim, seqlen, device=u.device)
        for t in range(seqlen):
            state = deltaA[..., t] * state + deltaB[..., t] * u[..., t]
            y[..., t] = (C[..., t] * state).sum(-1) + D * u[..., t]
        if z is not None:
            y = y * F.silu(z)
        return y
    
  5. Протестировать прототип на случайном тензоре.

Ожидаемый результат этапа Рабочая selective_scan_ref на Python; понимание всех математических операций.


Этап 2: Разработка CUDA-ядра (3–5 часов)

Действия

  1. Создать файл selective_scan_kernel.cu.

  2. Реализовать три версии kernel для сравнения:

    • наивная: однопоточная (1 thread per sequence element) — naive_scan_kernel
    • параллельная с warp-level: параллельный префикс-сумма в warp (не подходит, т.к. процесс рекуррентный) — покажите, что здесь нужен именно последовательный scan.
    • оптимизированная с использованием shared memory (построчная обработка с блоками).
  3. Базовая сигнатура kernel:

    __global__ void selective_scan_kernel(
        const float* u,   // [batch, dim, seqlen]
        const float* delta,
        const float* A,
        const float* B,
        const float* C,
        float* y,
        int batch, int dim, int seqlen
    ) {
        // Каждый thread обрабатывает один элемент (b,d) на всём seqlen
        int idx = blockIdx.x * blockDim.x + threadIdx.x;
        if (idx >= batch * dim) return;
        int b = idx / dim;
        int d = idx % dim;
        float state = 0.0f;
        for (int t = 0; t < seqlen; t++) {
            float u_t = u[(b*dim + d)*seqlen + t];
            float delta_t = delta[(b*dim + d)*seqlen + t];
            float A_val = A[d];           // один параметр на размерность
            float B_val = B[(b*dim + d)*seqlen + t]; // [b,d,t] если B обучаемый
            float C_val = C[(b*dim + d)*seqlen + t];
            float deltaA = expf(delta_t * A_val);
            float deltaB = delta_t * B_val;
            state = deltaA * state + deltaB * u_t;
            y[(b*dim + d)*seqlen + t] = C_val * state;
        }
    }
    
  4. Выделить память, скопировать данные и запустить kernel.

  5. Написать torch.autograd.Function для связывания kernel с PyTorch:

    import torch
    from torch.utils.cpp_extension import load_inline
    
    cuda_source = open('selective_scan_kernel.cu').read()
    module = load_inline('selective_scan', cuda_sources=[cuda_source],
                         functions=['selective_scan_forward'], verbose=True)
    
    class SelectiveScanFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, u, delta, A, B, C, D):
            # ... вызвать модуль.selective_scan_forward
            pass
        # backward (пропущен для упрощения, но может быть добавлен позже)
    

Ожидаемый результат этапа CUDA-функция selective_scan_forward, скомпилированная и доступная из Python, которая возвращает корректные выходные тензоры (проверить совпадение с эталоном на малых размерах).


Этап 3: Интеграция и тестирование (1–2 часа)

Действия

  1. Написать тестовый скрипт test_selective_scan.py.
  2. Сравнить вывод CUDA kernel с эталонной Python-реализацией для разных размеров:
    • batch = 2, dim = 4, seqlen = 8
    • batch = 1, dim = 32, seqlen = 256
    • batch = 8, dim = 128, seqlen = 64
  3. Замерить абсолютную разницу: (cuda_out - ref_out).abs().max() — должно быть <1e-5 (float).
  4. Проверить градиенты (если backward реализован): через torch.autograd.gradcheck.
  5. Создать benchmark: сравнить скорость CUDA kernel vs Python-цикл.

Ожидаемый результат этапа Все тесты пройдены; выведены замеры времени.


Этап 4: Оптимизация (опциональный, 1–2 часа)

Действия

  1. Использовать профилировщик Nsight Compute для выявления узких мест.
  2. Оптимизировать доступ к памяти: использовать coalesced чтение/запись (переупаковка тензоров в layout [seqlen, batch, dim]).
  3. Использовать __ldg (cached read) и __syncthreads при необходимости.
  4. Рассмотреть использование float4 для векторной загрузки.
  5. Написать второй kernel с этими оптимизациями и сравнить производительность.

Ожидаемый результат этапа Оптимизированная версия минимум в 2 раза быстрее наивной.


5. Критерии приемки (Definition of Done)

  • CUDA-код компилируется без ошибок и линкуется с PyTorch.
  • Функция selective_scan_forward корректно обрабатывает граничные случаи: dim=1, seqlen=1, batch=0 (недопустимо, но хотя бы не краш).
  • Разница между выходом CUDA kernel и эталонной Python-реализацией < 1e-5 по всем элементам (float32).
  • Есть автоматизированный тест pytest.
  • Документированы ограничения kernel (например, максимальный seqlen из-за регистров, только float32).
  • Реализован backward (опционально, но желательно для практического использования) — проверено gradcheck.
  • Время выполнения CUDA kernel меньше, чем у Python-цикла в 100+ раз для seqlen=2048, dim=256, batch=1.
  • Код выложен в Git-репозиторий с читаемым README.

6. Ожидаемый результат

Основной артефакт Python-модуль mamba_scan_cuda.cpython-*.so (или .ptx), который экспортирует функцию selective_scan_forward(u, delta, A, B, C, D) -> y.

Содержимое репозитория

.
├── selective_scan_kernel.cu    # CUDA kernel
├── selective_scan_wrapper.py   # PyTorch autograd function
├── test_selective_scan.py      # тесты
├── benchmark.py                # замеры производительности
├── selective_scan_ref.py       # эталон на Python
└── README.md                   # инструкция по сборке

Дополнительно График сравнения времени выполнения vs seqlen.


7. Возможные сложности и их решение

СложностьРешение
Ошибка компиляции CUDA: expected an identifierПроверить макросы; убедиться, что __global__ и __device__ корректны. Использовать простой стартовый шаблон.
Kernel долго считает или не сходитсяИспользовать printf в kernel для отладки (cuda-gdb). Проверить layout памяти – тензоры могут быть в порядке C vs row-major.
Неправильный broadcast A (параметры размерности)A обычно имеет форму (dim,) – нужно правильно применить exp(delta*A) с вещанием в kernel.
Memory access violationПроверить off-by-one в индексах; использовать cuda-memcheck.
Gradient computation not implementedПока опустить backward; для задания достаточно forward + численной проверки.
Коллапс при больших seqlen (регистровый спилл)Уменьшить количество временных переменных, использовать volatile или вынести часть в локальную память.

8. Бюджет времени (оценка)

ЭтапВремя
1. Изучение теории и прототип1.5 часа
2. Разработка CUDA-ядра4 часа
3. Интеграция и тестирование2 часа
4. Оптимизация (опционально)2 часа
Итого (основное)7.5 часов
Итого (с оптимизацией)9.5 часов

Примечание Для первого раза допустимо потратить до 12 часов с учётом отладки.


9. Связанные вопросы из базы знаний

ВопросТема
1Basic CUDA programming model
47PyTorch custom CUDA extensions
102State Space Models overview
203Attention vs SSM tradeoffs
310Writing fused kernels for recurrent operations
415Profiling CUDA kernels (Nsight)
520Mixed precision training with custom kernels
627Numerical stability of exp() in long sequences
734Parallel scan algorithms (prefix sum)
845Gradient computation for linear recurrence

10. Чек-лист самопроверки

  • Я понимаю математику selective scan и могу объяснить её на доске за 5 минут.
  • Моя CUDA-функция возвращает правильные значения на синтетическом тесте (совпадение с Python-эталоном с точностью 1e-5).
  • Я написал хотя бы один автоматический тест на pytest.
  • Я замерил скорость на seqlen=1024 и сравнил с наивной реализацией (разница > 100x).
  • В README указаны инструкции по сборке и запуску тестов.
  • Код оформлен согласно PEP8 (python) и CUDA C++ best practices.