Что такое sliding window attention и зачем он в Mistral?

Краткий тезис

Sliding window attention (SWA) — это механизм, при котором каждый токен «видит» только фиксированное окно из предыдущих W токенов (в Mistral W = 4096). Это снижает вычислительную сложность с O(n²) до O(n·W) и позволяет обрабатывать последовательности произвольной длины без роста памяти на KV-кэш. В Mistral SWA используется в комбинации с Rolling Buffer Cache для эффективной инференции длинных контекстов (до 32k токенов) и служит основой для масштабирования модели на большие последовательности.


1. Термин: Sliding Window Attention (SWA)

Sliding window attention — это разновидность sparse attention, где каждый токен вычисляет внимание только к фиксированному числу предыдущих токенов (окну). Окно «скользит» по последовательности: для токена на позиции i доступны токены с позиций i-W+1 по i (или i-W по i-1, в зависимости от реализации). В Mistral окно равно 4096 токенам.

Зачем это нужно attention|Полное внимание (full attention) требует O(n²) операций и памяти, что делает обработку длинных последовательностей (более 8-16k токенов) непомерно дорогой. SWA даёт линейную сложность по длине последовательности, сохраняя способность модели улавливать локальные зависимости.

Важно SWA не отменяет дальние зависимости полностью — они передаются через многослойную архитектуру. Каждый слой расширяет эффективное поле зрения: после L слоёв токен может «увидеть» до L·W токенов назад (за счёт рекуррентной передачи информации через окна).


2. Как работает SWA в Mistral

Mistral использует SWA с окном 4096 и Rolling Buffer Cache для хранения KV-пар только в пределах окна. Рассмотрим детали:

  • Прямой проход (forward): Для каждого токена вычисляется attention только к токенам в окне. Маска внимания — диагональная полоса шириной W.
  • KV-кэш Вместо хранения всех KV-пар (что потребовало бы O(n·d) памяти), хранится только последние W пар. При генерации нового токена старые пары вытесняются (rolling buffer).
  • Pre-fill (предзаполнение): Для длинных промптов (например, 32k токенов) используется chunking — промпт разбивается на чанки размером W, и каждый чанк обрабатывается с учётом предыдущего окна. Это позволяет эффективно заполнить KV-кэш без квадратичного роста.

Формально
Пусть последовательность длины n, окно W. Для позиции i:

  • Attention scores: ( [text](/wiki/text){score}_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d}} ) для j ∈ [max(0, i-W+1), i].
  • Выход: ( o_i = \sum_{j} [text](/wiki/text){softmax}([text](/wiki/text){score}_{i,j}) v_j ).

Сложность: O(n·W·d) вместо O(n²·d). При W << n (типично W=4096, n может быть 32k) выигрыш значителен.


3. Зачем SWA в Mistral: ключевые мотивации

Mistral 7B была спроектирована для эффективного инференса на длинных последовательностях. SWA решает три основные проблемы:

ПроблемаРешение с SWA
Квадратичная сложность full attentionЛинейная O(n·W)
Рост KV-кэша с длиной контекстаRolling Buffer Cache фиксированного размера (W·d)
Ограничение длины контекстаВозможность обрабатывать последовательности любой длины (теоретически) за счёт многослойного распространения

Практический эффект Mistral 7B поддерживает контекст до 32k токенов (с некоторыми оговорками) при значительно меньших затратах памяти и времени по сравнению с Llama 2 7B (full attention, 4k контекст). Это позволяет использовать Mistral в RAG-системах с большими документами, агентных сценариях с длинной историей и других задачах, требующих длинного контекста.


4. Сравнение SWA с другими механизмами внимания

МеханизмСложностьПамять (KV-кэш)Дальние зависимостиПримеры моделей
Full attentionO(n²)O(n·d)Прямые, любыеGPT-2, Llama 2, BERT
Sliding window (SWA)O(n·W)O(W·d)Через слои (L·W)Mistral, Longformer
Dilated attentionO(n·W)O(W·d)Разреженные, с шагомLongformer, BigBird
Global + local attentionO(n·W + n·G)O((W+G)·d)Глобальные токеныBigBird, Longformer
Flash AttentionO(n²) (но быстрее)O(n·d) (но меньше)ПолныеGPT-4, Llama 3

Ключевое отличие SWA фиксированное окно без глобальных токенов — простота реализации и предсказуемое потребление памяти. Недостаток — ограниченное прямое поле зрения, что может снижать качество на задачах, требующих точных дальних связей (например, поиск информации в конце документа).


5. Реализация в Mistral: детали

Mistral использует SWA с окном 4096 во всех слоях. Дополнительные оптимизации:

  • Rolling Buffer Cache KV-кэш реализован как циклический буфер размера W. При генерации нового токена самый старый элемент заменяется. Это позволяет не выделять память под всю последовательность.
  • Pre-fill and Chunking Для длинных промптов (например, 32k) промпт разбивается на чанки по 4096 токенов. Каждый чанк обрабатывается последовательно, при этом KV-кэш предыдущего чанка сохраняется (но только последние W токенов). Это даёт возможность «увидеть» весь контекст через многослойную передачу.
  • Masking: Маска внимания — нижняя треугольная матрица с шириной полосы W. В реализации используется Flash Attention (опционально) для ускорения вычислений.

Пример кода (упрощённая реализация SWA на PyTorch):

import torch
import torch.nn.functional as F

def sliding_window_attention(q, k, v, window_size):
    # q, k, v: (batch, heads, seq_len, dim)
    B, H, T, D = q.shape
    scores = torch.matmul(q, k.transpose(-2, -1)) / (D ** 0.5)  # (B, H, T, T)
    # Маска: True для позиций вне окна
    mask = torch.ones(T, T, device=q.device).triu(diagonal=window_size).bool()
    scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
    attn = F.softmax(scores, dim=-1)
    out = torch.matmul(attn, v)
    return out

6. Преимущества SWA в Mistral

  • Линейная сложность Позволяет обрабатывать последовательности в 8-16 раз длиннее, чем full attention, при тех же вычислительных ресурсах.
  • Фиксированная память KV-кэш не растёт с длиной контекста — это критично для серверного инференса и деплоя на edge-устройствах.
  • Простота реализации Не требует сложных паттернов sparse attention (как BigBird) или дополнительных глобальных токенов.
  • Совместимость с Flash Attention SWA может быть реализована через Flash Attention с соответствующей маской, что даёт дополнительное ускорение.

7. Ограничения и недостатки

  • Ограниченное прямое поле зрения Токен не может напрямую обратить внимание на токен, находящийся дальше W позиций. Хотя многослойная архитектура частично компенсирует это, для задач, требующих точного сопоставления далёких позиций (например, coreference resolution через длинные расстояния), SWA может уступать full attention.
  • Потеря информации при больших окнах Если окно слишком мало (например, 1024), модель может «забывать» ранние части контекста. Mistral выбрала 4096 как компромисс между эффективностью и качеством.
  • Необходимость глубоких слоёв Для эффективного распространения информации требуется достаточное количество слоёв (в Mistral 7B — 32 слоя). Мелкие модели с SWA могут страдать от недостатка «дальнего зрения».

8. Связь с другими механизмами внимания

SWA — частный случай sparse attention. Другие подходы:

  • Dilated sliding window Окно с шагом (dilation), как в Longformer. Позволяет увеличить эффективное поле зрения без роста W.
  • Global attention Некоторые токены (например, [CLS]) имеют доступ ко всем токенам. Используется в BigBird и Longformer для задач классификации.
  • Grouped-Query Attention (GQA): В Mistral также используется GQA (8 ключей на 32 головы запроса) для снижения размера KV-кэша. SWA и GQA ортогональны и могут комбинироваться.

9. Пет-проект для закрепления

Задача Реализовать простой трансформер с sliding window attention и сравнить его производительность с full attention на задаче генерации текста (например, обучение на небольшом корпусе).

Инструменты PyTorch, Hugging Face Transformers (для baseline), Weights & Biases (для логирования).

Шаги:

  1. Реализовать класс SlidingWindowAttention с параметром window_size.
  2. Собрать небольшой трансформер (2-4 слоя, 4 головы, dim=256) с возможностью переключения между full и SWA.
  3. Обучить обе версии на датасете (например, WikiText-2) на задачу language modeling.
  4. Замерить:
    • Время обучения на эпоху.
    • Пиковое использование GPU памяти.
    • Perplexity на валидации.
  5. Построить графики зависимости perplexity от длины контекста (например, 512, 1024, 2048) для обеих моделей.

Ожидаемый результат Вы увидите, что SWA потребляет значительно меньше памяти и быстрее обучается, но может показывать slightly худший perplexity на длинных контекстах (особенно если окно мало). Эксперимент наглядно демонстрирует trade-off между эффективностью и качеством.


10. Связь с другими вопросами

ВопросТема
280Архитектура Mistral 7B (общий обзор)
282Rolling Buffer Cache в Mistral
283Grouped-Query Attention (GQA)
284Mistral 7B vs Llama 2 7B
286Flash Attention и его роль
287Sparse Attention: виды и применение

Навигация