Что такое logit lens и как он помогает понимать внутренние представления?

Краткий тезис

Logit lens — это техника интерпретации, которая проецирует скрытые состояния (hidden states) промежуточных слоёв LLM на выходное пространство токенов (через матрицу unembedding). Она показывает, как token предсказание следующего токена формируется по мере прохождения слоёв: на ранних слоях распределение размыто, на средних появляются ключевые концепции, а на поздних — окончательное уточнение. Logit lens позволяет заглянуть в «мышление» модели без дополнительного обучения.


1. Что такое logit lens?

Logit lens (буквально «линза логитов») — метод, впервые описанный в контексте отладки трансформеров (например, в статье "Transformer Debugging" Ноама Шазира). Идея: взять state|скрытое состояние после любого слоя, пропустить его через LM head (матрицу unembedding + bias) и получить логиты для всех токенов словаря. После softmax эти логиты интерпретируются как вероятности, которые модель «думает» на данном слое.

Ключевая особенность: не требуется дообучение — используется уже существующая матрица unembedding, которая обычно применяется только на последнем слое. Это делает logit lens простым и быстрым инструментом.


2. Как работает logit lens?

Шаги:

  1. Подать на вход модели последовательность токенов и получить скрытые состояния (hidden states) для каждого слоя. Обычно берут состояния после residual stream (выход слоя до следующего attention/FFN).
  2. Для каждого слоя $l$ взять state|скрытое состояние $h_l$ (размерность $d_{model}$).
  3. Умножить $h_l$ на матрицу unembedding $W_U$ (размер $d_{model} \times V$, где $V$ — размер словаря) и прибавить bias $b_U$ (если есть): $[text](/wiki/text){logits}_l = h_l \cdot W_U + b_U$.
  4. Применить softmax: $p_l = [text](/wiki/text){softmax}([text](/wiki/text){logits}_l)$.
  5. Интерпретировать $p_l$ как распределение вероятностей следующего токена на слое $l$.

Пример (упрощённый код на Python с библиотекой transformers):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

input_text = "The capital of France is"
inputs = tokenizer(input_text, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    hidden_states = outputs.hidden_states  # tuple of (batch, seq_len, d_model) for each layer

# Берём скрытое состояние последнего токена (предсказываем следующий)
last_token_hidden = [hs[:, -1, :] for hs in hidden_states]  # list of (batch, d_model)

# Применяем unembedding (lm_head)
lm_head = model.lm_head
bias = model.lm_head.bias if hasattr(model.lm_head, 'bias') else None

for layer_idx, h in enumerate(last_token_hidden):
    logits = lm_head(h)
    probs = torch.softmax(logits, dim=-1)
    top5 = torch.topk(probs, 5)
    print(f"Layer {layer_idx}: top tokens: {[tokenizer.decode(t) for t in top5.indices[0]]}")

Результат покажет, как топ-токены меняются от слоя к слою: на ранних — шум, на средних — «Paris», на поздних — уточнение.


3. Зачем нужен logit lens?

Logit lens помогает ответить на вопросы:

  • На каком слое модель «решает», какой токен предсказать? Например, в фактологических задачах ответ может появиться уже на средних слоях, а поздние слои лишь корректируют стиль или грамматику.
  • Какие концепции модель «видит» на разных глубинах? Ранние слои часто отвечают за синтаксис, средние — за семантику, поздние — за контекст и согласование.
  • Где происходит «исправление» ошибок Если на раннем слое модель предсказывает неверный токен, а на позднем — правильный, можно понять, какой слой внёс коррекцию.

Это особенно полезно для mechanistic interpretability — исследования внутренних механизмов модели.


4. Связь с архитектурой Transformer

В трансформере информация течёт через residual stream — сумму выходов всех слоёв. Каждый слой (attention + FFN) добавляет свой вклад. Logit lens показывает, как residual stream постепенно накапливает информацию, необходимую для предсказания.

  • Attention отвечает за перемешивание информации между токенами.
  • FFN — за извлечение фактов и преобразование признаков.
  • Logit lens после attention или FFN может показать, какой из компонентов даёт наибольший вклад в финальное предсказание.

5. Ограничения и альтернативы

МетодОписаниеПлюсыМинусы
Logit lensПроекция скрытых состояний через unembeddingПростота, без обученияШум на ранних слоях; unembedding обучен только для последнего слоя
Tuned lensОбучается линейная проекция для каждого слояБолее точные вероятностиТребует дополнительного обучения
ProbingОбучается классификатор поверх скрытых состоянийГибкость (можно предсказывать любые свойства)Нужны размеченные данные; может переобучаться
Activation patchingЗамена активаций на одном слое на другойПозволяет выявить причинно-следственные связиВычислительно дорого

Основное ограничение logit lens: матрица unembedding оптимизирована для последнего слоя, поэтому для ранних слоёв распределение может быть некалиброванным. Тем не менее, качественные паттерны (появление правильного токена) обычно видны.


6. Пример интерпретации с помощью logit lens

Рассмотрим пример с моделью GPT-2 и предложением "The Eiffel Tower is located in". Ожидаемый следующий токен — « Paris».

СлойТоп-3 токена (вероятность)Комментарий
0" the" (0.12), " a" (0.10), " in" (0.08)Шум, общие слова
5" Paris" (0.25), " France" (0.15), " the" (0.10)Появляется правильный ответ
10" Paris" (0.60), " France" (0.20), " city" (0.05)Уверенность растёт
12 (последний)" Paris" (0.85), " France" (0.10), " the" (0.02)Финальное предсказание

Видно, что модель «знает» ответ уже на 5-м слое, а последующие слои лишь усиливают уверенность.


7. Пет-проект для закрепления

Задача: Написать скрипт, который для заданного текста визуализирует, как меняется топ-5 токенов на каждом слое модели (например, GPT-2 или Llama-2-7B).

Инструменты: Python, PyTorch, transformers, matplotlib (или plotly для интерактива).

Шаги:

  1. Загрузить модель с output_hidden_states=True.
  2. Для каждого слоя получить скрытое состояние последнего токена.
  3. Применить lm_head и softmax.
  4. Построить тепловую карту (слои × токены) или анимацию изменения вероятности целевого токена.
  5. Протестировать на нескольких примерах: фактологический запрос, грамматическая конструкция, творческий текст.

Ожидаемый результат: График, на котором видно, на каком слое целевой токен входит в топ-5, и как его вероятность растёт. Это наглядно демонстрирует, как модель «думает» по слоям.


8. Связь с другими вопросами

ВопросТема
290Что такое residual stream в трансформере?
291Как работает attention и зачем он нужен?
294Что такое probing и как он используется для анализа LLM?
296Как работает activation patching?
300Какие методы интерпретации LLM вы знаете?
310Что такое mechanistic interpretability?

Навигация