Aivaro
  • Оглавление
  • Вопросы
  • Практика
  • Вики
  • Материалы сообщества
  • Тесты
  • Поиск
✈Telegram @ai_varo
RUEN中文
…
Оглавление/Вопросы/#935

В чем проблема Vanishing Gradient в RNN и как LSTM её решает?

Краткий тезис

RNN страдает от затухания градиентов (Vanishing Gradient) при обработке длинных последовательностей: градиент ошибки экспоненциально уменьшается при обратном распространении во времени, что блокирует обучение долгосрочным зависимостям. LSTM решает эту проблему, вводя ячейку памяти (cell state) с линейными связями и забывающий вентиль (forget gate), которые позволяют градиенту проходить через временные шаги без существенного изменения.

---------------------|---------------------| | Производная tanh | ≤ 1, часто < 0.25 | | Матрица весов | Собственные числа < 1 → затухание | | Количество шагов (T) | Экспоненциальное уменьшение |

Последствия: RNN запоминает только короткие паттерны (до 5–10 шагов), а для текстов или временных рядов с долгой зависимостью (например, согласование подлежащего в начале предложения с глаголом в конце) выдаёт плохие предсказания.


2. LSTM: ячейка памяти (cell) с линейными связями

LSTM модифицирует архитектуру RNN, вводя отдельный конвейер для долговременной информации — cell state (c_t). В отличие от скрытого состояния (h_t), которое каждый раз обновляется через нелинейность, (c_t) может проходить через ячейку с линейными преобразованиями: сложением и умножением на вентили.

Устройство LSTM-ячейки:

  • (x_t) — входной вектор в момент (t);
  • (h_{t-1}) — скрытое состояние предыдущего шага;
  • (c_{t-1}) — предыдущее состояние ячейки.

Три вентиля (gate) управляются сигмоидой (\sigma), их выходы лежат в (0,1):

  • forget gate (f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)) — решает, что забыть из (c_{t-1});
  • input gate (i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)) — решает, какую новую информацию записать;
  • output gate (o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)) — определяет, какую часть (c_t) выводить в (h_t).

Новый кандидат для ячейки: (\tilde{c}t = [tanh](/wiki/tanh)(W_c \cdot [h{t-1}, x_t] + b_c)).

Обновление (c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t) — линейная комбинация старого состояния и кандидата, взвешенная вентилями.


3. Forget gate решает, что забыть

Forget gate — ключевой механизм для управления долговременной памятью. Он принимает (h_{t-1}) и (x_t), пропускает через сигмоиду и умножает элемент-wise на (c_{t-1}).

  • Если (f_t) близок к 1 → информация сохраняется почти без изменений.
  • Если (f_t) близок к 0 → ячейка "забывает" соответствующие компоненты.

Благодаря этому LSTM может избирательно стирать ненужную информацию, не затрагивая остальные каналы памяти. Это решает проблему "взрыва/затухания" при умножении: если модель решает, что контекст ещё актуален, (f_t) остаётся высоким, и градиент не обрезается.


4. Градиент может проходить неизменным через cell

Самое важное отличие LSTM от простой RNN — carousel (конвейер) для градиента. В формуле (c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}t) при обратном распространении градиент (\frac{\partial L}{\partial c_t}) частично переходит к (\frac{\partial L}{\partial c{t-1}}) через (f_t). Если (f_t = 1) (забываемый вентиль полностью открыт), то градиент передаётся один к одному, без умножения на веса и нелинейности:

[ \frac{\partial c_t}{\partial c_{t-1}} = f_t \quad ([text](/wiki/text){при условии, что } i_t,\tilde{c}t [text](/wiki/text){ не зависят от } c{t-1}). ]

Таким образом, на пути от (c_t) к (c_{t-1}) стоит только множитель (f_t) (от 0 до 1), а не произведение матриц и тангенсов. Если forget gate установлен в 1 для нужного числа шагов, градиент может "пробегать" через десятки шагов без затухания. Это делает LSTM эффективным для задач, требующих запоминания на сотни и тысячи шагов (например, распознавание речи, машинный перевод).

Практический эффект: LSTM стала стандартом для последовательностей до появления Transformer. Она стабильно обучается на длинных контекстах, тогда как простая RNN — нет.


5. Пет-проект для закрепления

Задача: реализовать с нуля на PyTorch RNN и LSTM (по одной ячейке) для предсказания следующего символа в искусственно созданном тексте с долгой зависимостью (например, "a...b" с расстоянием в 50 символов). Сравнить значения градиентов первых шагов при BPTT.

Инструменты:

  • Python, PyTorch, NumPy.
  • Jupyter Notebook для визуализации.

Шаги:

  1. Сгенерировать последовательность вида: "X...Y", где X и Y имеют семантическую связь (например, "a" и "b" — открывающий и закрывающий теги, между ними шум).
  2. Реализовать класс SimpleRNN с одной рекуррентной ячейкой и LSTMCell.
  3. Обучить обе модели на задаче прогноза следующего символа (cross-entropy).
  4. В процессе обучения сохранять градиенты по весам для каждого временного шага (hook'и или retain_grad).
  5. Построить график нормы градиента от номера шага (backward).
  6. Сделать вывод: в RNN градиент падает к нулю за несколько шагов, в LSTM — остаётся значимым на десятки шагов.

Ожидаемый результат:

  • RNN не сможет выучить зависимость между X и Y (accurary ~ случайная).
  • LSTM покажет значительно более высокую точность (например, 90%+).
  • График градиента подтвердит отсутствие затухания в LSTM.

Связь с другими вопросами

ВопросТема
934Основы RNN и проблема долгосрочных зависимостей
(доп.) 936Сравнение LSTM и GRU — почему GRU проще, но также решает vanishing gradient

Навигация

  • Предыдущий: 934
  • Следующий: 936
  • Индекс: 00. Индекс разборов