Как работает кросс-энтропия (cross-entropy loss) для LLM обучения?

Краткий тезис

Кросс-энтропия — это стандартная функция потерь для обучения языковых моделей (LLM) на задаче предсказания следующего токена (token token prediction|next token prediction). Она измеряет разницу между истинным распределением вероятностей (one-hot вектор правильного токена) и предсказанным распределением (softmax от логитов модели). Минимизация кросс-энтропии эквивалентна максимизации логарифма вероятности правильного токена, что напрямую соответствует принципу максимального правдоподобия (MLE). Градиент этой функции имеет простую форму p - y_true, что упрощает обратное распространение.


1. Термин: Кросс-энтропия (Cross-Entropy)

Кросс-энтропия — это мера из теории информации, которая оценивает, насколько распределение q (предсказанное моделью) отличается от истинного распределения p. Для дискретных распределений формула:

H(p, q) = - Σ p(x) * log q(x)

В контексте LLM p — это one-hot вектор (единица на позиции правильного токена, нули на остальных), а qsoftmax вероятности, выданные моделью для всех токенов словаря.

Почему именно кросс-энтропия

  • Она выпукла и дифференцируема, что удобно для градиентного спуска.
  • Её минимум достигается, когда q совпадает с p, то есть модель присваивает вероятность 1 правильному токену.
  • Она напрямую связана с перплексией (perplexity): perplexity = exp(H).

2. Роль кросс-энтропии в обучении LLM

LLM обучаются на огромных корпусах текстов в режиме self-supervised learning. Для каждого токена в последовательности модель предсказывает следующий токен, используя все предыдущие. Функция потерь считается для каждого предсказания и усредняется по всей последовательности.

Пример:

  • Вход: "Сегодня хорошая"
  • Цель (target): "погода"
  • Модель выдаёт логиты z размером vocab_size.
  • Softmax превращает логиты в вероятности q.
  • Кросс-энтропия: -1 * log(q["погода"]) (так как one-hot даёт 1 только для "погода").

Таким образом, обучение сводится к тому, чтобы увеличить вероятность правильного токена и уменьшить вероятности всех остальных.


3. Формула и её компоненты

Полная формула для одного примера (одного токена):

L = - Σ_{i=1}^{V} y_i * log(ŷ_i)

где:

  • V — размер словаря (vocabulary size).
  • y_i — истинная метка (0 или 1, one-hot).
  • ŷ_i — предсказанная вероятность i-го токена (после softmax).

Для батча из N токенов (усреднение):

L = - (1/N) * Σ_{j=1}^{N} Σ_{i=1}^{V} y_{j,i} * log(ŷ_{j,i})

Связь с softmax Softmax преобразует логиты z_i в вероятности:

ŷ_i = exp(z_i) / Σ_{k=1}^{V} exp(z_k)

Тогда кросс-энтропия принимает вид:

L = - log( exp(z_c) / Σ_k exp(z_k) ) = - z_c + log( Σ_k exp(z_k) )

где c — индекс правильного токена. Это log-softmax + negative log-likelihood.


4. Градиент кросс-энтропии

Градиент по логитам z имеет удивительно простую форму:

∂L / ∂z_i = ŷ_i - y_i

Вывод (кратко):

  • Для правильного токена c: ∂L/∂z_c = ŷ_c - 1.
  • Для остальных токенов i ≠ c: ∂L/∂z_i = ŷ_i - 0 = ŷ_i.

Это означает, что градиент направлен на уменьшение вероятности всех токенов, кроме правильного, и на увеличение вероятности правильного. Такая форма делает обратное распространение эффективным и численно стабильным.

Пример:

  • Правильный токен: "погода" (индекс 42).
  • ŷ_42 = 0.3, ŷ_100 = 0.1.
  • Градиент по z_42: 0.3 - 1 = -0.7 (увеличиваем логит).
  • Градиент по z_100: 0.1 - 0 = 0.1 (уменьшаем логит).

5. Интерпретация: максимизация логарифма правдоподобия

Минимизация кросс-энтропии эквивалентна максимизации логарифма правдоподобия (MLE). Для одного токена:

L = - log P(правильный токен | контекст)

Суммируя по всем токенам в корпусе, мы максимизируем логарифм вероятности всех данных. Это фундаментальный принцип обучения генеративных моделей.

Связь с перплексией Перплексия = exp(L) (для одного токена). Чем ниже перплексия, тем лучше модель предсказывает данные.


6. Проблемы и улучшения

6.1 Численная стабильность

Прямое вычисление log(softmax) может привести к переполнению из-за больших экспонент. На практике используют LogSoftmax (функция log_softmax в PyTorch) и CrossEntropyLoss, который объединяет softmax и negative log-likelihood в одну численно устойчивую операцию.

6.2 Label Smoothing

Жёсткое one-hot кодирование может приводить к переобучению и излишней уверенности модели. Label smoothing заменяет one-hot на мягкое распределение:

y_i' = (1 - ε) * y_i + ε / V

где ε — небольшой коэффициент (например, 0.1). Это улучшает обобщение и калибровку вероятностей.

6.3 Взвешенная кросс-энтропия

Если классы несбалансированы (редкие токены), можно добавить веса:

L = - Σ w_i * y_i * log(ŷ_i)

где w_i обратно пропорциональны частоте токена.


7. Сравнение с другими функциями потерь

Функция потерьФормулаПрименение в NLP
Cross-Entropy-Σ y log ŷСтандарт для классификации токенов
Mean Squared Error (MSE)Σ (y - ŷ)²Регрессия, не подходит для вероятностей
Hinge Lossmax(0, 1 - y*ŷ)SVM, редко в LLM
Contrastive Loss(1-y)*d² + y*max(0, margin-d)²Обучение эмбеддингов

Кросс-энтропия предпочтительнее MSE для вероятностных задач, так как она сильнее штрафует за большие ошибки в логарифмическом масштабе и лучше работает с softmax.


8. Реализация на PyTorch

import torch
import torch.nn as nn

# Гипотетические логиты для батча из 2 примеров, словарь из 1000 токенов
logits = torch.randn(2, 1000)          # (batch_size, vocab_size)
targets = torch.tensor([42, 315])      # индексы правильных токенов

# Вариант 1: ручное вычисление (не рекомендуется из-за стабильности)
# softmax = torch.softmax(logits, dim=-1)
# loss = -torch.log(softmax[range(2), targets]).mean()

# Вариант 2: встроенная функция (численно устойчивая)
criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)      # уже включает softmax внутри

print(loss.item())                     # значение потерь

Важно CrossEntropyLoss в PyTorch ожидает на вход логиты, а не вероятности, и применяет softmax автоматически.


9. Пет-проект для закрепления

Задача Обучить маленькую языковую модель (например, 2-слойный LSTM) на небольшом корпусе (сказки Пушкина) и визуализировать динамику кросс-энтропии и перплексии.

Инструменты PyTorch, torchtext, matplotlib.

Шаги:

  1. Загрузить и токенизировать текст (построить словарь ~5000 токенов).
  2. Создать даталоадер с последовательностями длины 64.
  3. Определить модель: EmbeddingLSTMLinear → vocab_size.
  4. Обучить с CrossEntropyLoss и оптимизатором Adam.
  5. Логировать loss и perplexity на каждой эпохе.
  6. Построить график: loss по шагам, perplexity по эпохам.
  7. Попробовать label smoothing (ε=0.1) и сравнить кривые.

Ожидаемый результат Понимание, как кросс-энтропия уменьшается в процессе обучения, как перплексия падает, и как label smoothing влияет на уверенность модели.


10. Связь с другими вопросами

ВопросТема
650Как работает softmax и зачем он нужен в LLM?
651Что такое перплексия и как она связана с loss?
652Как устроен процесс обучения LLM (pre-training)?
653Что такое label smoothing и зачем он применяется?
654Как вычисляется градиент в нейронных сетях?
660Как работает fine-tuning LLM на своей задаче?

11. Навигация


Навигация