Что такое torch.compile и как он ускоряет training?
Краткий тезис
torch.compile — это JIT-компилятор PyTorch, который преобразует код, написанный в eager mode (пооператорное выполнение), в оптимизированный graph|вычислительный граф. Он ускоряет training на 20–40% за счёт фузии операций, уменьшения накладных расходов на Python-интерпретатор и оптимизации использования памяти. Однако компилятор чувствителен к dynamic shapes (изменяющимся размерам тензоров), что может вызывать перекомпиляции и снижать выигрыш, особенно при обучении LLM.
1. Проблема eager execution в PyTorch
По умолчанию PyTorch работает в eager mode: каждая операция (например, torch.matmul, torch.relu) выполняется сразу, а Python-код управляет потоком данных. Это даёт гибкость (можно использовать if, циклы, отладку), но приводит к накладным расходам:
- Kernel launch overhead — каждый вызов операции запускает отдельное ядро на GPU, что занимает микросекунды, но в сумме даёт заметную задержку.
- Python overhead — интерпретатор Python обрабатывает каждую строку, создаёт временные объекты и управляет памятью.
- Отсутствие глобальных оптимизаций — eager mode не видит весь граф вычислений]], поэтому не может, например, объединить последовательные операции в одно ядро.
Сравнение eager и compiled режимов:
| Характеристика | Eager mode | Compiled mode (torch.compile) |
|---|---|---|
| Гибкость | Высокая | Средняя (ограничения на dynamic control flow) |
| Скорость выполнения | Низкая (из-за overhead) | Высокая (оптимизированные fused kernels) |
| Использование памяти | Больше временных буферов | Меньше за счёт переиспользования |
| Отладка | Лёгкая | Сложнее (граф скрыт) |
2. Как работает torch.compile: три этапа
torch.compile использует архитектуру из двух ключевых компонентов: TorchDynamo (захват графа) и TorchInductor (генерация кода]]). Процесс состоит из трёх этапов:
2.1 Захват графа (TorchDynamo)
TorchDynamo перехватывает выполнение Python-функции и строит FX-граф (представление вычислений в виде узлов-операций). Он анализирует байт-код Python и определяет, какие операции можно скомпилировать. Если встречается некомпилируемый код (например, вызов внешней библиотеки), Dynamo возвращается к eager mode для этого участка.
2.2 Компиляция (TorchInductor)
Полученный граф передаётся в TorchInductor — backend, который генерирует оптимизированный код для целевого устройства (GPU, CPU). Inductor применяет:
- Fusion — объединение последовательных операций (например, matmul + bias + relu) в одно ядро.
- Triton kernels — генерация высокопроизводительных ядер на языке Triton (для NVIDIA GPU).
- Memory planning — переиспользование буферов и уменьшение числа аллокаций.
2.3 Выполнение
Скомпилированная функция заменяет исходную. При повторных вызовах с теми же размерами тензоров используется кэшированная версия, что даёт ускорение.
import torch
def train_step(model, x, y):
pred = model(x)
loss = torch.nn.functional.cross_entropy(pred, y)
loss.backward()
return loss
# Обычный eager вызов
loss = train_step(model, x, y)
# Скомпилированный вызов
compiled_step = torch.compile(train_step)
loss = compiled_step(model, x, y)
3. Режимы компиляции
torch.compile предлагает три основных режима, управляемых параметром mode:
| Режим | Описание | Типичное ускорение | Когда использовать |
|---|---|---|---|
"default" | Баланс между временем компиляции и производительностью | 20–30% | Общий случай |
"reduce-overhead" | Минимизация накладных расходов на запуск ядер | 30–40% | Модели с большим числом мелких операций |
"max-autotune" | Длительная автопоисковая оптимизация (пробует разные конфигурации ядер) | до 50% | Когда время компиляции не критично (например, долгий training) |
Пример использования:
model = torch.compile(model, mode="reduce-overhead")
4. За счёт чего достигается ускорение training
Ускорение в forward + backward складывается из нескольких факторов:
- Fusion операций — несколько последовательных операций объединяются в одно ядро. Например, linear -> layer_norm -> dropout может быть выполнено за один проход.
- Уменьшение числа kernel launches — вместо десятков вызовов ядер GPU выполняется несколько, что снижает latency.
- Оптимизация памяти — компилятор переиспользует временные буферы, уменьшая количество аллокаций и освобождений.
- Автоматическая настройка параметров ядер — Triton-ядра подбирают размеры блоков и warp-ов под конкретную архитектуру GPU.
Эксперименты PyTorch team показывают, что на типовых моделях (ResNet, BERT, GPT) ускорение составляет 20–40% для training и до 60% для inference.
5. Ограничения torch.compile
Несмотря на преимущества, torch.compile имеет ряд ограничений, критичных для LLM:
5.1 Dynamic shapes
Если размеры тензоров меняются от шага к шагу (например, разная длина последовательности в батче), компилятор не может закэшировать оптимизированное ядро и вынужден перекомпилировать граф при каждом новом shape. Это может свести на нет выигрыш в скорости.
Решение для LLM использовать static shapes — паддинг до фиксированной длины, bucketing (группировка последовательностей похожей длины) или отключать compile для частей модели, работающих с dynamic shapes.
5.2 Несовместимость с некоторыми операциями
Некоторые операции (например, torch.where с динамическим условием, torch.nonzero, сложные control flow) не поддерживаются или приводят к возврату в eager mode. В таких случаях компиляция не даёт ускорения.
5.3 Время первой компиляции
Первый вызов скомпилированной функции занимает больше времени из-за захвата графа и генерации кода. Для коротких training-запусков это может быть невыгодно.
5.4 Совместимость с распределённым обучением
Не все стратегии распределённого обучения (например, FSDP, DeepSpeed) полностью совместимы с torch.compile. Требуется тестирование.
6. Практические рекомендации для LLM
При обучении больших языковых моделей (LLM) torch.compile следует использовать с осторожностью:
- Используйте static shapes — фиксируйте длину последовательности (например, 512 или 1024 токена) с помощью паддинга. Это позволит избежать перекомпиляций.
- Отключайте compile для генерации — во время инференса (генерации текста) dynamic shapes неизбежны, поэтому лучше оставить eager mode для шагов декодирования.
- Применяйте compile только к forward — backward может быть менее выгоден из-за большого числа операций градиентов. Можно скомпилировать только forward-функцию.
- Профилируйте — используйте torch.profiler для сравнения времени с и без compile. Если ускорение менее 10%, возможно, стоит отказаться.
Пример для LLM:
class LLMModel(nn.Module):
def forward(self, input_ids, attention_mask):
# ... forward pass
return logits
model = LLMModel()
# Компилируем только forward для статичных размеров
model.forward = torch.compile(model.forward, mode="default", dynamic=False)
7. Сравнение с альтернативами
| Инструмент | Механизм | Преимущества | Недостатки |
|---|---|---|---|
| torch.compile | JIT-компиляция через TorchDynamo + TorchInductor | Простота использования, интеграция с PyTorch, поддержка GPU | Ограничения на dynamic shapes, время первой компиляции |
| XLA (JAX) | Ahead-of-time компиляция всего графа | Высокая производительность, поддержка TPU | Другой фреймворк (JAX), сложнее отладка |
| TensorRT | Оптимизация для инференса NVIDIA GPU | Максимальная скорость для inference | Не подходит для training, требует конвертации модели |
| ONNX Runtime | Кроссплатформенный оптимизатор | Поддержка разных бэкендов (CPU, GPU, NPU) | Меньшая гибкость, чем torch.compile |
8. Пример кода: замер ускорения
import torch
import time
# Простая модель
class SimpleModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc1 = torch.nn.Linear(1024, 4096)
self.fc2 = torch.nn.Linear(4096, 1024)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = self.fc2(x)
return x
model = SimpleModel().cuda()
x = torch.randn(64, 1024).cuda()
y = torch.randn(64, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# Eager mode
def train_eager():
optimizer.zero_grad()
out = model(x)
loss = torch.nn.functional.mse_loss(out, y)
loss.backward()
optimizer.step()
# Compiled mode
compiled_model = torch.compile(model)
def train_compiled():
optimizer.zero_grad()
out = compiled_model(x)
loss = torch.nn.functional.mse_loss(out, y)
loss.backward()
optimizer.step()
# Замер времени
import timeit
eager_time = timeit.timeit(train_eager, number=100)
compiled_time = timeit.timeit(train_compiled, number=100)
print(f"Eager: {eager_time:.3f}s, Compiled: {compiled_time:.3f}s, Speedup: {eager_time/compiled_time:.2f}x")
Ожидаемый результат: ускорение в 1.2–1.5 раза (20–50%).
Пет-проект для закрепления
Задача Сравнить скорость обучения небольшой трансформерной модели (например, 4 слоя, 4 головы внимания) на задаче классификации текстов с использованием torch.compile и без него.
Инструменты PyTorch, torch.compile, torch.profiler, timeit, датасет (например, AG_NEWS из torchtext).
Шаги:
- Реализовать модель трансформера (или взять готовую из torch.nn.TransformerEncoder).
- Написать цикл обучения с фиксированной длиной последовательности (static shape).
- Замерить время одной эпохи в eager mode.
- Обернуть модель в
torch.compile(model, mode="default")и замерить время. - Повторить для dynamic shapes (разная длина последовательностей) и зафиксировать перекомпиляции.
- Построить график ускорения в зависимости от размера батча.
Ожидаемый результат Ускорение 20–40% для static shapes; при dynamic shapes ускорение может быть меньше или отсутствовать.
Связь с другими вопросами
| Вопрос | Тема |
|---|---|
| 470 | Оптимизация памяти при обучении (градиентный чекпоинтинг, аккумуляция градиентов) |
| 471 | Смешанная точность (AMP) и её влияние на скорость |
| 472 | DeepSpeed и ZeRO для распределённого обучения |
| 474 | Профилирование производительности PyTorch (torch.profiler) |
| 475 | Оптимизация DataLoader и предобработки данных |
Навигация
- Предыдущий: 472
- Следующий: 474
- Индекс: 00. Индекс разборов