Что такое torch.compile и как он ускоряет training?

Краткий тезис

torch.compile — это JIT-компилятор PyTorch, который преобразует код, написанный в eager mode (пооператорное выполнение), в оптимизированный graph|вычислительный граф. Он ускоряет training на 20–40% за счёт фузии операций, уменьшения накладных расходов на Python-интерпретатор и оптимизации использования памяти. Однако компилятор чувствителен к dynamic shapes (изменяющимся размерам тензоров), что может вызывать перекомпиляции и снижать выигрыш, особенно при обучении LLM.


1. Проблема eager execution в PyTorch

По умолчанию PyTorch работает в eager mode: каждая операция (например, torch.matmul, torch.relu) выполняется сразу, а Python-код управляет потоком данных. Это даёт гибкость (можно использовать if, циклы, отладку), но приводит к накладным расходам:

  • Kernel launch overhead — каждый вызов операции запускает отдельное ядро на GPU, что занимает микросекунды, но в сумме даёт заметную задержку.
  • Python overhead — интерпретатор Python обрабатывает каждую строку, создаёт временные объекты и управляет памятью.
  • Отсутствие глобальных оптимизаций — eager mode не видит весь граф вычислений]], поэтому не может, например, объединить последовательные операции в одно ядро.

Сравнение eager и compiled режимов:

ХарактеристикаEager modeCompiled mode (torch.compile)
ГибкостьВысокаяСредняя (ограничения на dynamic control flow)
Скорость выполненияНизкая (из-за overhead)Высокая (оптимизированные fused kernels)
Использование памятиБольше временных буферовМеньше за счёт переиспользования
ОтладкаЛёгкаяСложнее (граф скрыт)

2. Как работает torch.compile: три этапа

torch.compile использует архитектуру из двух ключевых компонентов: TorchDynamo (захват графа) и TorchInductor (генерация кода]]). Процесс состоит из трёх этапов:

2.1 Захват графа (TorchDynamo)

TorchDynamo перехватывает выполнение Python-функции и строит FX-граф (представление вычислений в виде узлов-операций). Он анализирует байт-код Python и определяет, какие операции можно скомпилировать. Если встречается некомпилируемый код (например, вызов внешней библиотеки), Dynamo возвращается к eager mode для этого участка.

2.2 Компиляция (TorchInductor)

Полученный граф передаётся в TorchInductor — backend, который генерирует оптимизированный код для целевого устройства (GPU, CPU). Inductor применяет:

  • Fusion — объединение последовательных операций (например, matmul + bias + relu) в одно ядро.
  • Triton kernels — генерация высокопроизводительных ядер на языке Triton (для NVIDIA GPU).
  • Memory planning — переиспользование буферов и уменьшение числа аллокаций.

2.3 Выполнение

Скомпилированная функция заменяет исходную. При повторных вызовах с теми же размерами тензоров используется кэшированная версия, что даёт ускорение.

import torch

def train_step(model, x, y):
    pred = model(x)
    loss = torch.nn.functional.cross_entropy(pred, y)
    loss.backward()
    return loss

# Обычный eager вызов
loss = train_step(model, x, y)

# Скомпилированный вызов
compiled_step = torch.compile(train_step)
loss = compiled_step(model, x, y)

3. Режимы компиляции

torch.compile предлагает три основных режима, управляемых параметром mode:

РежимОписаниеТипичное ускорениеКогда использовать
"default"Баланс между временем компиляции и производительностью20–30%Общий случай
"reduce-overhead"Минимизация накладных расходов на запуск ядер30–40%Модели с большим числом мелких операций
"max-autotune"Длительная автопоисковая оптимизация (пробует разные конфигурации ядер)до 50%Когда время компиляции не критично (например, долгий training)

Пример использования:

model = torch.compile(model, mode="reduce-overhead")

4. За счёт чего достигается ускорение training

Ускорение в forward + backward складывается из нескольких факторов:

  • Fusion операций — несколько последовательных операций объединяются в одно ядро. Например, linear -> layer_norm -> dropout может быть выполнено за один проход.
  • Уменьшение числа kernel launches — вместо десятков вызовов ядер GPU выполняется несколько, что снижает latency.
  • Оптимизация памяти — компилятор переиспользует временные буферы, уменьшая количество аллокаций и освобождений.
  • Автоматическая настройка параметров ядер — Triton-ядра подбирают размеры блоков и warp-ов под конкретную архитектуру GPU.

Эксперименты PyTorch team показывают, что на типовых моделях (ResNet, BERT, GPT) ускорение составляет 20–40% для training и до 60% для inference.


5. Ограничения torch.compile

Несмотря на преимущества, torch.compile имеет ряд ограничений, критичных для LLM:

5.1 Dynamic shapes

Если размеры тензоров меняются от шага к шагу (например, разная длина последовательности в батче), компилятор не может закэшировать оптимизированное ядро и вынужден перекомпилировать граф при каждом новом shape. Это может свести на нет выигрыш в скорости.

Решение для LLM использовать static shapes — паддинг до фиксированной длины, bucketing (группировка последовательностей похожей длины) или отключать compile для частей модели, работающих с dynamic shapes.

5.2 Несовместимость с некоторыми операциями

Некоторые операции (например, torch.where с динамическим условием, torch.nonzero, сложные control flow) не поддерживаются или приводят к возврату в eager mode. В таких случаях компиляция не даёт ускорения.

5.3 Время первой компиляции

Первый вызов скомпилированной функции занимает больше времени из-за захвата графа и генерации кода. Для коротких training-запусков это может быть невыгодно.

5.4 Совместимость с распределённым обучением

Не все стратегии распределённого обучения (например, FSDP, DeepSpeed) полностью совместимы с torch.compile. Требуется тестирование.


6. Практические рекомендации для LLM

При обучении больших языковых моделей (LLM) torch.compile следует использовать с осторожностью:

  • Используйте static shapes — фиксируйте длину последовательности (например, 512 или 1024 токена) с помощью паддинга. Это позволит избежать перекомпиляций.
  • Отключайте compile для генерации — во время инференса (генерации текста) dynamic shapes неизбежны, поэтому лучше оставить eager mode для шагов декодирования.
  • Применяйте compile только к forward — backward может быть менее выгоден из-за большого числа операций градиентов. Можно скомпилировать только forward-функцию.
  • Профилируйте — используйте torch.profiler для сравнения времени с и без compile. Если ускорение менее 10%, возможно, стоит отказаться.

Пример для LLM:

class LLMModel(nn.Module):
    def forward(self, input_ids, attention_mask):
        # ... forward pass
        return logits

model = LLMModel()
# Компилируем только forward для статичных размеров
model.forward = torch.compile(model.forward, mode="default", dynamic=False)

7. Сравнение с альтернативами

ИнструментМеханизмПреимуществаНедостатки
torch.compileJIT-компиляция через TorchDynamo + TorchInductorПростота использования, интеграция с PyTorch, поддержка GPUОграничения на dynamic shapes, время первой компиляции
XLA (JAX)Ahead-of-time компиляция всего графаВысокая производительность, поддержка TPUДругой фреймворк (JAX), сложнее отладка
TensorRTОптимизация для инференса NVIDIA GPUМаксимальная скорость для inferenceНе подходит для training, требует конвертации модели
ONNX RuntimeКроссплатформенный оптимизаторПоддержка разных бэкендов (CPU, GPU, NPU)Меньшая гибкость, чем torch.compile

8. Пример кода: замер ускорения

import torch
import time

# Простая модель
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1024, 4096)
        self.fc2 = torch.nn.Linear(4096, 1024)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleModel().cuda()
x = torch.randn(64, 1024).cuda()
y = torch.randn(64, 1024).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Eager mode
def train_eager():
    optimizer.zero_grad()
    out = model(x)
    loss = torch.nn.functional.mse_loss(out, y)
    loss.backward()
    optimizer.step()

# Compiled mode
compiled_model = torch.compile(model)
def train_compiled():
    optimizer.zero_grad()
    out = compiled_model(x)
    loss = torch.nn.functional.mse_loss(out, y)
    loss.backward()
    optimizer.step()

# Замер времени
import timeit
eager_time = timeit.timeit(train_eager, number=100)
compiled_time = timeit.timeit(train_compiled, number=100)
print(f"Eager: {eager_time:.3f}s, Compiled: {compiled_time:.3f}s, Speedup: {eager_time/compiled_time:.2f}x")

Ожидаемый результат: ускорение в 1.2–1.5 раза (20–50%).


Пет-проект для закрепления

Задача Сравнить скорость обучения небольшой трансформерной модели (например, 4 слоя, 4 головы внимания) на задаче классификации текстов с использованием torch.compile и без него.

Инструменты PyTorch, torch.compile, torch.profiler, timeit, датасет (например, AG_NEWS из torchtext).

Шаги:

  1. Реализовать модель трансформера (или взять готовую из torch.nn.TransformerEncoder).
  2. Написать цикл обучения с фиксированной длиной последовательности (static shape).
  3. Замерить время одной эпохи в eager mode.
  4. Обернуть модель в torch.compile(model, mode="default") и замерить время.
  5. Повторить для dynamic shapes (разная длина последовательностей) и зафиксировать перекомпиляции.
  6. Построить график ускорения в зависимости от размера батча.

Ожидаемый результат Ускорение 20–40% для static shapes; при dynamic shapes ускорение может быть меньше или отсутствовать.


Связь с другими вопросами

ВопросТема
470Оптимизация памяти при обучении (градиентный чекпоинтинг, аккумуляция градиентов)
471Смешанная точность (AMP) и её влияние на скорость
472DeepSpeed и ZeRO для распределённого обучения
474Профилирование производительности PyTorch (torch.profiler)
475Оптимизация DataLoader и предобработки данных

Навигация