В чем проблема Vanishing Gradient в RNN и как LSTM её решает?
Краткий тезис
RNN страдает от затухания градиентов (Vanishing Gradient) при обработке длинных последовательностей: градиент ошибки экспоненциально уменьшается при обратном распространении во времени, что блокирует обучение долгосрочным зависимостям. LSTM решает эту проблему, вводя ячейку памяти (cell state) с линейными связями и забывающий вентиль (forget gate), которые позволяют градиенту проходить через временные шаги без существенного изменения.
---------------------|---------------------| | Производная tanh | ≤ 1, часто < 0.25 | | Матрица весов | Собственные числа < 1 → затухание | | Количество шагов (T) | Экспоненциальное уменьшение |
Последствия: RNN запоминает только короткие паттерны (до 5–10 шагов), а для текстов или временных рядов с долгой зависимостью (например, согласование подлежащего в начале предложения с глаголом в конце) выдаёт плохие предсказания.
2. LSTM: ячейка памяти (cell) с линейными связями
LSTM модифицирует архитектуру RNN, вводя отдельный конвейер для долговременной информации — cell state (c_t). В отличие от скрытого состояния (h_t), которое каждый раз обновляется через нелинейность, (c_t) может проходить через ячейку с линейными преобразованиями: сложением и умножением на вентили.
Устройство LSTM-ячейки:
- (x_t) — входной вектор в момент (t);
- (h_{t-1}) — скрытое состояние предыдущего шага;
- (c_{t-1}) — предыдущее состояние ячейки.
Три вентиля (gate) управляются сигмоидой (\sigma), их выходы лежат в (0,1):
- forget gate (f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)) — решает, что забыть из (c_{t-1});
- input gate (i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)) — решает, какую новую информацию записать;
- output gate (o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)) — определяет, какую часть (c_t) выводить в (h_t).
Новый кандидат для ячейки: (\tilde{c}t = [tanh](/wiki/tanh)(W_c \cdot [h{t-1}, x_t] + b_c)).
Обновление (c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t) — линейная комбинация старого состояния и кандидата, взвешенная вентилями.
3. Forget gate решает, что забыть
Forget gate — ключевой механизм для управления долговременной памятью. Он принимает (h_{t-1}) и (x_t), пропускает через сигмоиду и умножает элемент-wise на (c_{t-1}).
- Если (f_t) близок к 1 → информация сохраняется почти без изменений.
- Если (f_t) близок к 0 → ячейка "забывает" соответствующие компоненты.
Благодаря этому LSTM может избирательно стирать ненужную информацию, не затрагивая остальные каналы памяти. Это решает проблему "взрыва/затухания" при умножении: если модель решает, что контекст ещё актуален, (f_t) остаётся высоким, и градиент не обрезается.
4. Градиент может проходить неизменным через cell
Самое важное отличие LSTM от простой RNN — carousel (конвейер) для градиента. В формуле (c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}t) при обратном распространении градиент (\frac{\partial L}{\partial c_t}) частично переходит к (\frac{\partial L}{\partial c{t-1}}) через (f_t). Если (f_t = 1) (забываемый вентиль полностью открыт), то градиент передаётся один к одному, без умножения на веса и нелинейности:
[ \frac{\partial c_t}{\partial c_{t-1}} = f_t \quad ([text](/wiki/text){при условии, что } i_t,\tilde{c}t [text](/wiki/text){ не зависят от } c{t-1}). ]
Таким образом, на пути от (c_t) к (c_{t-1}) стоит только множитель (f_t) (от 0 до 1), а не произведение матриц и тангенсов. Если forget gate установлен в 1 для нужного числа шагов, градиент может "пробегать" через десятки шагов без затухания. Это делает LSTM эффективным для задач, требующих запоминания на сотни и тысячи шагов (например, распознавание речи, машинный перевод).
Практический эффект: LSTM стала стандартом для последовательностей до появления Transformer. Она стабильно обучается на длинных контекстах, тогда как простая RNN — нет.
5. Пет-проект для закрепления
Задача: реализовать с нуля на PyTorch RNN и LSTM (по одной ячейке) для предсказания следующего символа в искусственно созданном тексте с долгой зависимостью (например, "a...b" с расстоянием в 50 символов). Сравнить значения градиентов первых шагов при BPTT.
Инструменты:
Шаги:
- Сгенерировать последовательность вида: "X...Y", где X и Y имеют семантическую связь (например, "a" и "b" — открывающий и закрывающий теги, между ними шум).
- Реализовать класс
SimpleRNNс одной рекуррентной ячейкой иLSTMCell. - Обучить обе модели на задаче прогноза следующего символа (cross-entropy).
- В процессе обучения сохранять градиенты по весам для каждого временного шага (hook'и или retain_grad).
- Построить график нормы градиента от номера шага (backward).
- Сделать вывод: в RNN градиент падает к нулю за несколько шагов, в LSTM — остаётся значимым на десятки шагов.
Ожидаемый результат:
- RNN не сможет выучить зависимость между X и Y (accurary ~ случайная).
- LSTM покажет значительно более высокую точность (например, 90%+).
- График градиента подтвердит отсутствие затухания в LSTM.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 934 | Основы RNN и проблема долгосрочных зависимостей |
| (доп.) 936 | Сравнение LSTM и GRU — почему GRU проще, но также решает vanishing gradient |
Навигация
- Предыдущий: 934
- Следующий: 936
- Индекс: 00. Индекс разборов