中文翻译暂不可用,显示俄语原文。
Что такое sliding window attention и зачем он в Mistral?
Краткий тезис
Sliding window attention (SWA) — это механизм, при котором каждый токен «видит» только фиксированное окно из предыдущих W токенов (в Mistral W = 4096). Это снижает вычислительную сложность с O(n²) до O(n·W) и позволяет обрабатывать последовательности произвольной длины без роста памяти на KV-кэш. В Mistral SWA используется в комбинации с Rolling Buffer Cache для эффективной инференции длинных контекстов (до 32k токенов) и служит основой для масштабирования модели на большие последовательности.
1. Термин: Sliding Window Attention (SWA)
Sliding window attention — это разновидность sparse attention, где каждый токен вычисляет внимание только к фиксированному числу предыдущих токенов (окну). Окно «скользит» по последовательности: для токена на позиции i доступны токены с позиций i-W+1 по i (или i-W по i-1, в зависимости от реализации). В Mistral окно равно 4096 токенам.
Зачем это нужно attention|Полное внимание (full attention) требует O(n²) операций и памяти, что делает обработку длинных последовательностей (более 8-16k токенов) непомерно дорогой. SWA даёт линейную сложность по длине последовательности, сохраняя способность модели улавливать локальные зависимости.
Важно SWA не отменяет дальние зависимости полностью — они передаются через многослойную архитектуру. Каждый слой расширяет эффективное поле зрения: после L слоёв токен может «увидеть» до L·W токенов назад (за счёт рекуррентной передачи информации через окна).
2. Как работает SWA в Mistral
Mistral использует SWA с окном 4096 и Rolling Buffer Cache для хранения KV-пар только в пределах окна. Рассмотрим детали:
- Прямой проход (forward): Для каждого токена вычисляется attention только к токенам в окне. Маска внимания — диагональная полоса шириной W.
- KV-кэш Вместо хранения всех KV-пар (что потребовало бы O(n·d) памяти), хранится только последние W пар. При генерации нового токена старые пары вытесняются (rolling buffer).
- Pre-fill (предзаполнение): Для длинных промптов (например, 32k токенов) используется chunking — промпт разбивается на чанки размером W, и каждый чанк обрабатывается с учётом предыдущего окна. Это позволяет эффективно заполнить KV-кэш без квадратичного роста.
Формально
Пусть последовательность длины n, окно W. Для позиции i:
- Attention scores: ( [text](/wiki/text){score}_{i,j} = \frac{q_i \cdot k_j}{\sqrt{d}} ) для j ∈ [max(0, i-W+1), i].
- Выход: ( o_i = \sum_{j} [text](/wiki/text){softmax}([text](/wiki/text){score}_{i,j}) v_j ).
Сложность: O(n·W·d) вместо O(n²·d). При W << n (типично W=4096, n может быть 32k) выигрыш значителен.
3. Зачем SWA в Mistral: ключевые мотивации
Mistral 7B была спроектирована для эффективного инференса на длинных последовательностях. SWA решает три основные проблемы:
| Проблема | Решение с SWA |
|---|---|
| Квадратичная сложность full attention | Линейная O(n·W) |
| Рост KV-кэша с длиной контекста | Rolling Buffer Cache фиксированного размера (W·d) |
| Ограничение длины контекста | Возможность обрабатывать последовательности любой длины (теоретически) за счёт многослойного распространения |
Практический эффект Mistral 7B поддерживает контекст до 32k токенов (с некоторыми оговорками) при значительно меньших затратах памяти и времени по сравнению с Llama 2 7B (full attention, 4k контекст). Это позволяет использовать Mistral в RAG-системах с большими документами, агентных сценариях с длинной историей и других задачах, требующих длинного контекста.
4. Сравнение SWA с другими механизмами внимания
| Механизм | Сложность | Память (KV-кэш) | Дальние зависимости | Примеры моделей |
|---|---|---|---|---|
| Full attention | O(n²) | O(n·d) | Прямые, любые | GPT-2, Llama 2, BERT |
| Sliding window (SWA) | O(n·W) | O(W·d) | Через слои (L·W) | Mistral, Longformer |
| Dilated attention | O(n·W) | O(W·d) | Разреженные, с шагом | Longformer, BigBird |
| Global + local attention | O(n·W + n·G) | O((W+G)·d) | Глобальные токены | BigBird, Longformer |
| Flash Attention | O(n²) (но быстрее) | O(n·d) (но меньше) | Полные | GPT-4, Llama 3 |
Ключевое отличие SWA фиксированное окно без глобальных токенов — простота реализации и предсказуемое потребление памяти. Недостаток — ограниченное прямое поле зрения, что может снижать качество на задачах, требующих точных дальних связей (например, поиск информации в конце документа).
5. Реализация в Mistral: детали
Mistral использует SWA с окном 4096 во всех слоях. Дополнительные оптимизации:
- Rolling Buffer Cache KV-кэш реализован как циклический буфер размера W. При генерации нового токена самый старый элемент заменяется. Это позволяет не выделять память под всю последовательность.
- Pre-fill and Chunking Для длинных промптов (например, 32k) промпт разбивается на чанки по 4096 токенов. Каждый чанк обрабатывается последовательно, при этом KV-кэш предыдущего чанка сохраняется (но только последние W токенов). Это даёт возможность «увидеть» весь контекст через многослойную передачу.
- Masking: Маска внимания — нижняя треугольная матрица с шириной полосы W. В реализации используется Flash Attention (опционально) для ускорения вычислений.
Пример кода (упрощённая реализация SWA на PyTorch):
import torch
import torch.nn.functional as F
def sliding_window_attention(q, k, v, window_size):
# q, k, v: (batch, heads, seq_len, dim)
B, H, T, D = q.shape
scores = torch.matmul(q, k.transpose(-2, -1)) / (D ** 0.5) # (B, H, T, T)
# Маска: True для позиций вне окна
mask = torch.ones(T, T, device=q.device).triu(diagonal=window_size).bool()
scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
attn = F.softmax(scores, dim=-1)
out = torch.matmul(attn, v)
return out
6. Преимущества SWA в Mistral
- Линейная сложность Позволяет обрабатывать последовательности в 8-16 раз длиннее, чем full attention, при тех же вычислительных ресурсах.
- Фиксированная память KV-кэш не растёт с длиной контекста — это критично для серверного инференса и деплоя на edge-устройствах.
- Простота реализации Не требует сложных паттернов sparse attention (как BigBird) или дополнительных глобальных токенов.
- Совместимость с Flash Attention SWA может быть реализована через Flash Attention с соответствующей маской, что даёт дополнительное ускорение.
7. Ограничения и недостатки
- Ограниченное прямое поле зрения Токен не может напрямую обратить внимание на токен, находящийся дальше W позиций. Хотя многослойная архитектура частично компенсирует это, для задач, требующих точного сопоставления далёких позиций (например, coreference resolution через длинные расстояния), SWA может уступать full attention.
- Потеря информации при больших окнах Если окно слишком мало (например, 1024), модель может «забывать» ранние части контекста. Mistral выбрала 4096 как компромисс между эффективностью и качеством.
- Необходимость глубоких слоёв Для эффективного распространения информации требуется достаточное количество слоёв (в Mistral 7B — 32 слоя). Мелкие модели с SWA могут страдать от недостатка «дальнего зрения».
8. Связь с другими механизмами внимания
SWA — частный случай sparse attention. Другие подходы:
- Dilated sliding window Окно с шагом (dilation), как в Longformer. Позволяет увеличить эффективное поле зрения без роста W.
- Global attention Некоторые токены (например, [CLS]) имеют доступ ко всем токенам. Используется в BigBird и Longformer для задач классификации.
- Grouped-Query Attention (GQA): В Mistral также используется GQA (8 ключей на 32 головы запроса) для снижения размера KV-кэша. SWA и GQA ортогональны и могут комбинироваться.
9. Пет-проект для закрепления
Задача Реализовать простой трансформер с sliding window attention и сравнить его производительность с full attention на задаче генерации текста (например, обучение на небольшом корпусе).
Инструменты PyTorch, Hugging Face Transformers (для baseline), Weights & Biases (для логирования).
Шаги:
- Реализовать класс
SlidingWindowAttentionс параметромwindow_size. - Собрать небольшой трансформер (2-4 слоя, 4 головы, dim=256) с возможностью переключения между full и SWA.
- Обучить обе версии на датасете (например, WikiText-2) на задачу language modeling.
- Замерить:
- Время обучения на эпоху.
- Пиковое использование GPU памяти.
- Perplexity на валидации.
- Построить графики зависимости perplexity от длины контекста (например, 512, 1024, 2048) для обеих моделей.
Ожидаемый результат Вы увидите, что SWA потребляет значительно меньше памяти и быстрее обучается, но может показывать slightly худший perplexity на длинных контекстах (особенно если окно мало). Эксперимент наглядно демонстрирует trade-off между эффективностью и качеством.
10. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 280 | Архитектура Mistral 7B (общий обзор) |
| 282 | Rolling Buffer Cache в Mistral |
| 283 | Grouped-Query Attention (GQA) |
| 284 | Mistral 7B vs Llama 2 7B |
| 286 | Flash Attention и его роль |
| 287 | Sparse Attention: виды и применение |
Навигация
- Предыдущий: 280
- Следующий: 282
- Индекс: 00. Индекс разборов