Что такое logit lens и как он помогает понимать внутренние представления?
Краткий тезис
Logit lens — это техника интерпретации, которая проецирует скрытые состояния (hidden states) промежуточных слоёв LLM на выходное пространство токенов (через матрицу unembedding). Она показывает, как token предсказание следующего токена формируется по мере прохождения слоёв: на ранних слоях распределение размыто, на средних появляются ключевые концепции, а на поздних — окончательное уточнение. Logit lens позволяет заглянуть в «мышление» модели без дополнительного обучения.
1. Что такое logit lens?
Logit lens (буквально «линза логитов») — метод, впервые описанный в контексте отладки трансформеров (например, в статье "Transformer Debugging" Ноама Шазира). Идея: взять state|скрытое состояние после любого слоя, пропустить его через LM head (матрицу unembedding + bias) и получить логиты для всех токенов словаря. После softmax эти логиты интерпретируются как вероятности, которые модель «думает» на данном слое.
Ключевая особенность: не требуется дообучение — используется уже существующая матрица unembedding, которая обычно применяется только на последнем слое. Это делает logit lens простым и быстрым инструментом.
2. Как работает logit lens?
Шаги:
- Подать на вход модели последовательность токенов и получить скрытые состояния (hidden states) для каждого слоя. Обычно берут состояния после residual stream (выход слоя до следующего attention/FFN).
- Для каждого слоя $l$ взять state|скрытое состояние $h_l$ (размерность $d_{model}$).
- Умножить $h_l$ на матрицу unembedding $W_U$ (размер $d_{model} \times V$, где $V$ — размер словаря) и прибавить bias $b_U$ (если есть): $[text](/wiki/text){logits}_l = h_l \cdot W_U + b_U$.
- Применить softmax: $p_l = [text](/wiki/text){softmax}([text](/wiki/text){logits}_l)$.
- Интерпретировать $p_l$ как распределение вероятностей следующего токена на слое $l$.
Пример (упрощённый код на Python с библиотекой transformers):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "The capital of France is"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs, output_hidden_states=True)
hidden_states = outputs.hidden_states # tuple of (batch, seq_len, d_model) for each layer
# Берём скрытое состояние последнего токена (предсказываем следующий)
last_token_hidden = [hs[:, -1, :] for hs in hidden_states] # list of (batch, d_model)
# Применяем unembedding (lm_head)
lm_head = model.lm_head
bias = model.lm_head.bias if hasattr(model.lm_head, 'bias') else None
for layer_idx, h in enumerate(last_token_hidden):
logits = lm_head(h)
probs = torch.softmax(logits, dim=-1)
top5 = torch.topk(probs, 5)
print(f"Layer {layer_idx}: top tokens: {[tokenizer.decode(t) for t in top5.indices[0]]}")
Результат покажет, как топ-токены меняются от слоя к слою: на ранних — шум, на средних — «Paris», на поздних — уточнение.
3. Зачем нужен logit lens?
Logit lens помогает ответить на вопросы:
- На каком слое модель «решает», какой токен предсказать? Например, в фактологических задачах ответ может появиться уже на средних слоях, а поздние слои лишь корректируют стиль или грамматику.
- Какие концепции модель «видит» на разных глубинах? Ранние слои часто отвечают за синтаксис, средние — за семантику, поздние — за контекст и согласование.
- Где происходит «исправление» ошибок Если на раннем слое модель предсказывает неверный токен, а на позднем — правильный, можно понять, какой слой внёс коррекцию.
Это особенно полезно для mechanistic interpretability — исследования внутренних механизмов модели.
4. Связь с архитектурой Transformer
В трансформере информация течёт через residual stream — сумму выходов всех слоёв. Каждый слой (attention + FFN) добавляет свой вклад. Logit lens показывает, как residual stream постепенно накапливает информацию, необходимую для предсказания.
- Attention отвечает за перемешивание информации между токенами.
- FFN — за извлечение фактов и преобразование признаков.
- Logit lens после attention или FFN может показать, какой из компонентов даёт наибольший вклад в финальное предсказание.
5. Ограничения и альтернативы
| Метод | Описание | Плюсы | Минусы |
|---|---|---|---|
| Logit lens | Проекция скрытых состояний через unembedding | Простота, без обучения | Шум на ранних слоях; unembedding обучен только для последнего слоя |
| Tuned lens | Обучается линейная проекция для каждого слоя | Более точные вероятности | Требует дополнительного обучения |
| Probing | Обучается классификатор поверх скрытых состояний | Гибкость (можно предсказывать любые свойства) | Нужны размеченные данные; может переобучаться |
| Activation patching | Замена активаций на одном слое на другой | Позволяет выявить причинно-следственные связи | Вычислительно дорого |
Основное ограничение logit lens: матрица unembedding оптимизирована для последнего слоя, поэтому для ранних слоёв распределение может быть некалиброванным. Тем не менее, качественные паттерны (появление правильного токена) обычно видны.
6. Пример интерпретации с помощью logit lens
Рассмотрим пример с моделью GPT-2 и предложением "The Eiffel Tower is located in". Ожидаемый следующий токен — « Paris».
| Слой | Топ-3 токена (вероятность) | Комментарий |
|---|---|---|
| 0 | " the" (0.12), " a" (0.10), " in" (0.08) | Шум, общие слова |
| 5 | " Paris" (0.25), " France" (0.15), " the" (0.10) | Появляется правильный ответ |
| 10 | " Paris" (0.60), " France" (0.20), " city" (0.05) | Уверенность растёт |
| 12 (последний) | " Paris" (0.85), " France" (0.10), " the" (0.02) | Финальное предсказание |
Видно, что модель «знает» ответ уже на 5-м слое, а последующие слои лишь усиливают уверенность.
7. Пет-проект для закрепления
Задача: Написать скрипт, который для заданного текста визуализирует, как меняется топ-5 токенов на каждом слое модели (например, GPT-2 или Llama-2-7B).
Инструменты: Python, PyTorch, transformers, matplotlib (или plotly для интерактива).
Шаги:
- Загрузить модель с
output_hidden_states=True. - Для каждого слоя получить скрытое состояние последнего токена.
- Применить
lm_headи softmax. - Построить тепловую карту (слои × токены) или анимацию изменения вероятности целевого токена.
- Протестировать на нескольких примерах: фактологический запрос, грамматическая конструкция, творческий текст.
Ожидаемый результат: График, на котором видно, на каком слое целевой токен входит в топ-5, и как его вероятность растёт. Это наглядно демонстрирует, как модель «думает» по слоям.
8. Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 290 | Что такое residual stream в трансформере? |
| 291 | Как работает attention и зачем он нужен? |
| 294 | Что такое probing и как он используется для анализа LLM? |
| 296 | Как работает activation patching? |
| 300 | Какие методы интерпретации LLM вы знаете? |
| 310 | Что такое mechanistic interpretability? |
Навигация
- Предыдущий: 294
- Следующий: 296
- Индекс: 00. Индекс разборов