中文翻译暂不可用,显示俄语原文。
Hyena: как заменить attention на свертки, сохранив качество?
Краткий тезис
Hyena — это архитектура, заменяющая механизм attention в трансформерах на long convolutional filters (filters|длинные сверточные фильтры). Она достигает линейно-логарифмической сложности O(n log n) за счёт использования БПФ (быстрого преобразования Фурье) для вычисления сверток длиной до 128k токенов. На задачах с длинным контекстом Hyena показывает качество, сопоставимое с трансформером, но не поддерживает произвольный доступ к прошлым токенам, что ограничивает её применение в задачах, требующих точного recall.
1. Проблема attention: квадратичная сложность
Механизм self-attention (внимание к самому себе) в трансформерах имеет сложность O(n²) по длине последовательности n. Это означает, что при удвоении длины контекста время и память растут в четыре раза. Для длинных контекстов (более 8k токенов) это становится непрактичным: даже на современных GPU модель с attention не может обработать последовательность из 100k токенов из-за ограничений памяти. Эта проблема стимулировала поиск альтернатив, таких как linear attention, state space models и сверточные архитектуры.
2. Идея Hyena: замена attention на long convolutions
Hyena (полное название — Hyena Hierarchy: Towards Larger Convolutional Language Models, 2023) предлагает заменить матрицу attention на комбинацию длинных сверточных фильтров и поэлементных умножений. Основная операция — Hyena Operator, который состоит из трёх этапов:
- Проекция входной последовательности через несколько фильтров разной длины (обычно 2–4 фильтра на слой).
- Элемент-вайз умножение (pointwise multiplication) между проекциями — аналог gating (стробирования), где одни фильтры управляют информационным потоком других.
- Обратная проекция для получения выходного представления.
Математически для одного слоя с двумя фильтрами (u и v) операция выглядит так:
y = (u * x) ⊙ (v * x)
где * — свертка, ⊙ — поэлементное умножение. На практике используется больше фильтров и иерархическая структура.
3. Long convolutional filters и FFT
Длина фильтров в Hyena может достигать 128k токенов. Прямое вычисление свертки такой длины имело бы сложность O(n·L), где L — длина фильтра, что неприемлемо. Hyena использует быстрое преобразование Фурье (FFT), чтобы вычислить свертку за O(n log n). Алгоритм:
- Дополняем фильтр и входную последовательность нулями до ближайшей степени двойки.
- Вычисляем дискретное преобразование Фурье (DFT) через FFT для обеих последовательностей.
- Поэлементно перемножаем спектры.
- Выполняем обратное FFT.
Фильтры обучаются либо в частотной области (как комплексные коэффициенты), либо параметризуются как импульсные характеристики (impulse responses) во временной области с последующим FFT.
4. Архитектура Hyena Hierarchy
Hyena состоит из нескольких слоёв (Hyena operators), организованных в иерархию. Каждый слой использует разное количество фильтров и разные длины. Типичная конфигурация:
- Входная проекция: линейный слой, увеличивающий размерность (например, с d_model до 2·d_model).
- Несколько параллельных сверток с разными фильтрами (обычно 2–4).
- Поэлементное умножение результатов в определённом порядке (например, (ux) ⊙ (vx) ⊙ (w*x)).
- Обратная проекция: линейный слой, возвращающий исходную размерность.
Иерархия означает, что выход одного Hyena operator подаётся на вход следующего, формируя глубокую сеть. Это напоминает gating mechanism (механизм стробирования), где одни фильтры управляют информационным потоком других.
5. Сравнение Hyena и Attention
| Характеристика | Attention | Hyena |
|---|---|---|
| Сложность | O(n²) | O(n log n) |
| Длина контекста | Ограничена памятью (обычно до 8k) | До 128k и более |
| Произвольный доступ к прошлым токенам | Да (каждый токен может обращаться к любому) | Нет (только через свертку, фиксированное окно) |
| Качество на long-range tasks | Высокое | Сопоставимое |
| Обучение | Стабильно | Требует специальной инициализации фильтров |
| Параллелизм | Высокий (матричные умножения) | Высокий (FFT хорошо параллелится) |
| Память | O(n²) для матрицы attention | O(n) для хранения фильтров |
6. Результаты экспериментов
Hyena показала качество на уровне трансформера на бенчмарках LRA (Long Range Arena) и SCROLLS. Например:
- На задаче ListOps (классификация вложенных операций с длинными последовательностями) Hyena превзошла трансформер на 2–3% точности.
- На Pathfinder (поиск связности в изображениях) трансформер остаётся лучше из-за необходимости точного recall позиций.
- На Text Classification (длинные документы) Hyena достигает сопоставимой точности при значительно меньшем времени обучения.
Важно: Hyena не превосходит трансформер на всех задачах, но для long-range задач с гладкими зависимостями она работает отлично.
7. Ограничения Hyena
- Нет произвольного доступа: свертка обрабатывает последовательность как сигнал, не позволяя токену «заглянуть» в произвольное место. Это делает Hyena менее подходящей для задач, где важна точная идентификация позиции (например, извлечение информации, question answering с точными span'ами).
- Инициализация фильтров: фильтры нужно инициализировать специальным образом (например, как экспоненциально затухающие импульсы), иначе обучение может быть нестабильным. Стандартная инициализация (например, Xavier) приводит к взрыву градиентов.
- Не подходит для задач с короткими контекстами: на коротких последовательностях (менее 512 токенов) overhead от FFT может быть больше, чем у attention, и качество может уступать.
- Сложность реализации: требуется аккуратное управление длинами фильтров, паддингом и FFT, что усложняет отладку.
8. Применение Hyena в RAG
В контексте RAG (Retrieval-Augmented Generation) Hyena может быть полезна для обработки длинных документов (например, целых книг) в качестве backbone для энкодера. Однако retrieval (поиск релевантных чанков) всё равно необходим, так как Hyena не решает проблему выбора релевантной информации — она просто эффективно обрабатывает длинные последовательности. Комбинация Hyena + retrieval может дать выигрыш в скорости при работе с большими контекстами (например, энкодер на Hyena обрабатывает весь документ, а retrieval выбирает нужные чанки для генерации).
9. Реализация Hyena на Python (концептуально)
import torch
import torch.nn as nn
import torch.nn.functional as F
class HyenaLayer(nn.Module):
def __init__(self, d_model, num_filters=2, max_filter_len=1024):
super().__init__()
self.num_filters = num_filters
# Фильтры: обучаемые импульсные характеристики
self.filters = nn.Parameter(torch.randn(num_filters, max_filter_len))
# Проекции
self.proj_in = nn.Linear(d_model, d_model * num_filters)
self.proj_out = nn.Linear(d_model, d_model)
# Инициализация фильтров (экспоненциальное затухание)
with torch.no_grad():
t = torch.arange(max_filter_len).float()
for i in range(num_filters):
alpha = 0.1 * (i + 1)
self.filters[i] = torch.exp(-alpha * t)
def forward(self, x):
# x: (batch, seq_len, d_model)
batch, seq_len, d_model = x.shape
# Проекция
proj = self.proj_in(x) # (batch, seq_len, d_model * num_filters)
proj = proj.view(batch, seq_len, self.num_filters, d_model)
# Свертка через FFT для каждого фильтра
out = 0
for i in range(self.num_filters):
# Дополнение до степени двойки
L = seq_len
n_fft = 1 << (L - 1).bit_length() # ближайшая степень двойки
# Фильтр дополняем нулями
f_pad = F.pad(self.filters[i], (0, n_fft - self.filters.size(1)))
# FFT
x_fft = torch.fft.rfft(proj[..., i, :], n=n_fft)
f_fft = torch.fft.rfft(f_pad)
# Свертка в частотной области
conv = torch.fft.irfft(x_fft * f_fft, n=n_fft)[:, :L, :]
# Элемент-вайз умножение с другой проекцией (gating)
gate = proj[..., (i+1) % self.num_filters, :]
out = out + conv * gate
# Обратная проекция
out = self.proj_out(out)
return out
10. Связь с другими подходами
Hyena относится к семейству State Space Models (SSM), таких как S4 и Mamba. Все они заменяют attention на рекуррентные или сверточные операции с линейной сложностью. Hyena можно рассматривать как частный случай SSM с длинными фильтрами, где рекуррентность заменена явной сверткой. В отличие от Mamba, Hyena не использует вход-зависимую рекуррентность (input-dependent recurrence), что делает её менее гибкой, но проще в реализации и более стабильной при обучении.
11. Как Hyena сохраняет качество: интуиция
Attention позволяет каждому токену «смотреть» на любой другой токен с разными весами. Свертка с длинным фильтром может аппроксимировать это, если фильтр имеет достаточную длину и выразительность. Hyena использует несколько фильтров и элемент-вайз умножение, что создаёт нелинейное взаимодействие, подобное attention. Теоретически показано (в оригинальной статье), что Hyena может представлять любую функцию, которую может attention, но с меньшей сложностью, если функция обладает гладкостью (smoothness) — то есть зависимости между токенами не слишком резкие. Для задач, где важны точные позиционные соответствия (например, Pathfinder), гладкость нарушается, и Hyena уступает.
12. Практические аспекты: инициализация и обучение
Фильтры в Hyena инициализируются как экспоненциально затухающие импульсы (exponential decay) или синусоидальные модуляции. Это важно для стабильности градиентов: резкие фильтры приводят к высокочастотным компонентам, которые плохо обучаются. Также используется LayerNorm (нормализация по слоям) и residual connections (остаточные связи). Обучение Hyena может быть чувствительно к скорости обучения — рекомендуется использовать warmup (постепенное увеличение learning rate) и меньший learning rate, чем для трансформера (например, 1e-4 вместо 3e-4). Для длинных последовательностей (>16k) рекомендуется использовать mixed precision training (обучение со смешанной точностью) для экономии памяти.
Пет-проект для закрепления
Задача Реализовать слой Hyena и сравнить его с attention на задаче классификации длинных последовательностей (например, датасет IMDB с длинными отзывами или синтетические данные с длинными паттернами).
Инструменты PyTorch, NumPy, библиотека для FFT (torch.fft), датасет (например, Long Range Arena или собственный генератор).
Шаги:
- Реализовать класс
HyenaLayerс параметрами:d_model,num_filters,max_filter_len. - Инициализировать фильтры как экспоненциально затухающие импульсы:
f[t] = exp(-alpha * t). - Реализовать forward: проекция, свертка через FFT, элемент-вайз умножение, обратная проекция.
- Реализовать аналогичный
AttentionLayer(один слой multi-head attention) для сравнения. - Обучить обе модели на задаче бинарной классификации (например, положительный/отрицательный отзыв) с длиной последовательности до 8192 токенов.
- Сравнить:
- Время обучения на эпоху.
- Потребление GPU памяти.
- Точность (accuracy) на валидации.
- Влияние длины последовательности на скорость.
Ожидаемый результат Hyena должна быть быстрее на длинных последовательностях (например, >4096 токенов) и потреблять меньше памяти, при этом точность должна быть сопоставима с attention (разница не более 1–2%). На коротких последовательностях attention может быть быстрее из-за overhead FFT.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 710 | Attention: механизм и проблемы |
| 711 | Long context: методы расширения контекста |
| 712 | S4: State Space Models |
| 713 | Mamba: альтернатива attention |
| 714 | Linear Attention: Performer, Linformer |
| 716 | FlashAttention: оптимизация attention |
Навигация
- Предыдущий: 714
- Следующий: 716
- Индекс: 00. Индекс разборов