Как работает YaRN (Yet another RoPE extensioN) для увеличения контекста?

Краткий тезис

YaRN — это метод расширения контекстного окна моделей на основе RoPE (Rotary Position Embedding), который комбинирует интерполяцию позиций с частотным масштабированием и температурным параметром. Он позволяет увеличить контекст Llama-2 с 4K до 32K–128K токенов всего за 1000 шагов дообучения, сохраняя качество на коротких последовательностях. Без YaRN модель не способна обобщать на позиции, выходящие за пределы обученного диапазона.


1. Термин: RoPE (Rotary Position Embedding)

RoPE — это метод кодирования позиций токенов в трансформерах, который вращает эмбеддинги в комплексной плоскости. Каждая позиция m получает уникальное вращение, зависящее от частот θ_i = base^(-2i/d), где d — размерность эмбеддинга, base — обычно 10000.

Ключевое свойство: RoPE позволяет модели естественно обобщать на относительные позиции, но не на абсолютные, выходящие за диапазон, на котором модель обучалась. Если модель обучена на контексте длины L_train, то при длине L_test > L_train частоты вращения для новых позиций становятся невиданными, и модель даёт мусор.

Проблема экстраполяции: RoPE не поддерживает экстраполяцию за пределы обученной длины. Даже небольшое превышение (например, 4K → 5K) резко ухудшает качество.


2. Базовая идея: Position Interpolation (PI)

Первый подход к расширению контекста — Position Interpolation (PI). Он просто сжимает все позиции в обученный диапазон:

  • Новая позиция m' = m / s, где s — коэффициент масштаба (например, s=8 для 4K → 32K).
  • Таким образом, все позиции от 0 до 32K отображаются в диапазон [0, 4K].

Недостаток PI: после интерполяции частоты вращения становятся слишком низкими, модель теряет способность различать близкие позиции (разрешение падает). Требуется дообучение на длинных контекстах, но даже после него качество на коротких контекстах может ухудшиться.


3. NTK-aware scaling

NTK-aware scaling (Neural Tangent Kernel) — улучшение PI, основанное на наблюдении, что высокочастотные компоненты RoPE отвечают за локальные детали, а низкочастотные — за глобальную структуру. Вместо равномерного сжатия всех частот, NTK-метод масштабирует только низкие частоты, оставляя высокие почти неизменными:

  • Для каждого θ_i применяется свой коэффициент: θ_i' = θ_i * s^(2i/(d-2)).
  • Высокие частоты (большие i) почти не меняются, низкие — сжимаются.

Это сохраняет разрешение на малых расстояниях и позволяет модели лучше обобщать. Однако NTK всё равно требует дообучения и не идеален для очень больших масштабов (s > 16).


4. Dynamic NTK

Dynamic NTK — адаптивная версия NTK, где коэффициент масштаба s вычисляется динамически в зависимости от текущей длины последовательности:

s_dynamic = max(1, L_test / L_train)

Если длина меньше или равна обученной, масштаб равен 1 (никаких изменений). Если больше — применяется NTK-scaling с этим s. Это позволяет модели работать на любой длине без дообучения, но качество на очень длинных контекстах (s > 32) всё ещё падает.


5. YaRN: комбинация PI + NTK + температурный параметр

YaRN (Yet another RoPE extensioN) объединяет лучшие черты PI и NTK, добавляя ключевой ингредиент — температурный параметр t.

5.1 Растяжение позиций

YaRN, как и PI, интерполирует позиции: m' = m / s. Но в отличие от PI, он не применяет интерполяцию ко всем частотам равномерно.

5.2 Частотное масштабирование (NTK-стиль)

YaRN использует NTK-подобное масштабирование частот, но с модификацией:

θ_i' = θ_i * (s * t)^(2i/(d-2))

Здесь t — температурный параметр, обычно близкий к 1 (например, 0.994 для s=8). Он слегка корректирует масштаб, чтобы сбалансировать интерполяцию и сохранение высоких частот.

5.3 Температурный параметр

Температура t — это гиперпараметр, который контролирует, насколько сильно сжимаются высокие частоты. При t=1 YaRN эквивалентен NTK-scaling. Уменьшение t (например, до 0.99) сильнее сжимает высокие частоты, что улучшает стабильность на очень длинных контекстах, но может снизить разрешение на коротких.

Эмпирическое правило: для масштаба s оптимальное t лежит в диапазоне [0.95, 1.0] и подбирается экспериментально.

5.4 Итоговая формула

Для каждой позиции m и пары измерений (2i, 2i+1):

θ_i' = θ_i * (s * t)^(2i/(d-2))
m' = m / s
cos(m' * θ_i'), sin(m' * θ_i') — используются в RoPE

На практике YaRN реализуется путём замены cos и sin в коде RoPE на предвычисленные значения с новыми частотами.


6. Результаты и дообучение

Авторы YaRN протестировали метод на Llama-2 7B (обучена на контексте 4K токенов). Результаты:

Масштаб sЦелевая длинаPerplexity (длинный контекст)Perplexity (короткий контекст)Шагов дообучения
832K3.23.1 (без потерь)1000
1664K3.53.2 (небольшое ухудшение)1000
32128K4.03.5 (умеренное ухудшение)1000

Ключевые выводы:

  • YaRN позволяет расширить контекст в 8–32 раза всего за 1000 шагов дообучения (на одном GPU A100).
  • Качество на коротких контекстах практически не падает (для s ≤ 16).
  • Без дообучения YaRN работает хуже, чем Dynamic NTK, но после 1000 шагов превосходит все другие методы.

7. Сравнение методов расширения контекста

МетодИнтерполяция позицийМасштабирование частотТемператураДообучениеКачество на короткихМакс. масштаб
PIДа (равномерно)НетНетТребуетсяПадает8x
NTKНетДа (неравномерно)НетТребуетсяХорошее16x
Dynamic NTKНетДа (адаптивно)НетНе нужноОтличное32x (с падением)
YaRNДа (равномерно)Да (NTK-стиль)Да1000 шаговОтличное32x+

8. Практическое применение в RAG и AI-агентах

В контексте Agentic RAG длинные контексты критичны:

  • Агент может обрабатывать историю диалога из десятков тысяч токенов.
  • Retrieval может возвращать много документов, которые нужно поместить в контекст.
  • YaRN позволяет использовать одну модель для разных масштабов без переобучения под каждый кейс.

Пример: Llama-2 7B с YaRN (s=8) может обработать 32K контекста, что достаточно для 10–15 средних документов (по 2K токенов) плюс запрос.


9. Ограничения и нюансы

  • YaRN требует дообучения (хотя и минимального). Без дообучения лучше использовать Dynamic NTK.
  • Температурный параметр t нужно подбирать под конкретную модель и масштаб.
  • Для очень больших масштабов (s > 32) качество всё равно падает — возможно, нужны архитектурные изменения (например, ALiBi или параллельные контексты).
  • YaRN не решает проблему "lost in the middle" — модель всё равно хуже работает с информацией в середине длинного контекста.

Пет-проект для закрепления

Задача: Реализовать YaRN для небольшой модели (например, GPT-2 или TinyLlama) и проверить, как меняется perplexity на длинных последовательностях.

Инструменты: Python, Hugging Face Transformers, PyTorch, Datasets (для загрузки длинных текстов, например, PG-19).

Шаги:

  1. Загрузить предобученную модель с RoPE (например, TinyLlama/TinyLlama-1.1B-Chat-v1.0).
  2. Извлечь конфиг RoPE (частоты θ_i).
  3. Реализовать функцию yarn_rope:
    • Принимает s (scale) и t (temperature).
    • Вычисляет новые частоты: theta_i * (s * t)^(2i/(d-2)).
    • Генерирует cos/sin для позиций до max_len * s.
  4. Заменить стандартный RoPE в модели на YaRN (через monkey-patch forward метода).
  5. Собрать датасет длинных текстов (например, первые 50K токенов из PG-19).
  6. Дообучить модель на 1000 шагов с s=4 (расширение с 2K до 8K).
  7. Оценить perplexity на тестовых последовательностях длины 2K, 4K, 8K.
  8. Сравнить с baseline (без YaRN) и с Dynamic NTK.

Ожидаемый результат: Perplexity на длинных контекстах (8K) значительно снизится после дообучения с YaRN, а на коротких (2K) останется на уровне baseline.


Связь с другими вопросами

ВопросТема
642Как работает RoPE (Rotary Position Embedding)?
644Чем YaRN отличается от Position Interpolation и NTK-aware scaling?
645Как работает NTK-aware scaling для расширения контекста?
646Что такое Dynamic NTK и когда его использовать?
647Как дообучать модель на длинные контексты (fine-tuning for long context)?
648Какие ещё методы расширения контекста существуют (ALiBi, xPos, etc.)?

Навигация