English translation is not available yet. Showing Russian content.

Как работает graph optimization в LLM компиляторах (constant folding, dead code elimination)?

Краткий тезис

Graph optimization в LLM компиляторах — это набор техник, применяемых к вычислительному графу модели для уменьшения времени выполнения и использования памяти. Constant folding вычисляет константные выражения на этапе компиляции, заменяя их готовыми значениями. Dead code elimination удаляет операции, результаты которых не используются. Эти оптимизации снижают launch overhead (накладные расходы на запуск ядер) и ускоряют инференс, особенно на небольших батчах.


1. Что такое LLM компиляторы и зачем нужна оптимизация графа

LLM компилятор — это специализированная программа, которая преобразует высокоуровневое описание нейросети (например, модель PyTorch или TensorFlow) в эффективный исполняемый код для конкретного оборудования (GPU, CPU, NPU). В отличие от традиционных компиляторов (gcc, LLVM), LLM компиляторы работают с вычислительным графом — направленным ациклическим графом, где узлы — операции (сложение, умножение матриц, активации), а рёбра — тензоры данных.

Оптимизация графа необходима, потому что:

  • Исходный граф, построенный фреймворком, часто содержит избыточные операции (например, повторные вычисления одних и тех же констант).
  • Без оптимизации каждый вызов модели запускает множество мелких ядер, что увеличивает overhead|launch overhead|launch overhead.
  • Современные LLM содержат миллиарды параметров, и даже небольшие улучшения в графе дают значительный прирост производительности.

2. Вычислительный граф и его представление

graph|Вычислительный граф может быть представлен в виде IR (regression|Intermediate Representation) — промежуточного представления. Примеры IR: MLIR (Multi-Level Intermediate Representation), XLA HLO, TVM Relay. IR позволяет применять оптимизации, не зависящие от исходного фреймворка.

Простой пример графа (псевдокод):

a = Constant(2.0)
b = Constant(3.0)
c = Add(a, b)   # 2.0 + 3.0 = 5.0
d = Mul(c, x)   # 5.0 * x
e = Sin(d)
f = Exp(e)
g = Add(f, c)   # использует c
h = Mul(g, y)
output = h

Здесь a, b — константы, c — результат сложения констант (тоже константа), x, y — входные переменные.


3. Constant folding (свёртка констант)

Constant folding — это оптимизация, при которой выражения, все операнды которых являются константами, вычисляются на этапе компиляции, а не во время выполнения. Результат заменяется на новую константу.

Как работает

  1. Компилятор обходит граф и находит узлы, у которых все входы — константы.
  2. Вычисляет значение такого узла (например, Add(2.0, 3.0) → 5.0).
  3. Заменяет узел на новую константу.
  4. Удаляет исходные константы, если они больше не используются (это уже часть code elimination|code elimination|dead code elimination).

Пример до/после

До:
  a = Constant(2.0)
  b = Constant(3.0)
  c = Add(a, b)   # 5.0
  d = Mul(c, x)   # 5.0 * x

После:
  c = Constant(5.0)
  d = Mul(c, x)   # 5.0 * x

Узлы a, b и операция Add удалены.

Влияние на LLM В моделях часто встречаются константные маски (attention mask), позиционные кодировки, нормализационные параметры. Их свёртка уменьшает количество операций на каждый forward pass.


4. Dead code elimination (удаление мёртвого кода)

Dead code elimination (DCE) — удаление узлов, результаты которых не влияют на выход графа (не используются ни одним другим узлом, либо используются только другими мёртвыми узлами).

Как работает

  1. Компилятор помечает выходные узлы графа как "живые".
  2. Рекурсивно обходит граф от выходов к входам, помечая все узлы, от которых зависит выход, как "живые".
  3. Все непомеченные узлы удаляются.

Пример:

До:
  a = Input(x)
  b = Mul(a, 2.0)   # используется
  c = Sin(a)        # не используется
  d = Add(b, 1.0)
  output = d

После:
  a = Input(x)
  b = Mul(a, 2.0)
  d = Add(b, 1.0)
  output = d

Узел c удалён, так как его результат нигде не используется.

Влияние на LLM В процессе разработки модели могут оставаться "висячие" операции (например, отладочные выводы, неиспользуемые ветви). DCE очищает граф, уменьшая количество запускаемых ядер.


5. Другие важные оптимизации графа

ОптимизацияОписаниеПример в LLM
Operator fusionОбъединение нескольких последовательных операций в одно ядро (например, MatMul + Add + ReLU).Уменьшает launch overhead и улучшает локальность данных.
Memory planningПереиспользование памяти для тензоров, которые живут в непересекающиеся промежутки времени.Снижает пиковое потребление памяти при инференсе.
Algebraic simplificationУпрощение выражений: x * 0 → 0, x + 0 → x, x / 1 → x.Убирает лишние операции.
Common subexpression elimination (CSE)Если одно и то же выражение вычисляется дважды, результат сохраняется и переиспользуется.Например, повторное вычисление softmax в多头注意力.
Shape specializationФиксация размеров тензоров на этапе компиляции, что позволяет генерировать более эффективный код.Особенно полезно для моделей с фиксированным размером входа.

6. Влияние на launch overhead

Launch overhead — это время, затрачиваемое на запуск каждого ядра на GPU (передача команд, установка параметров, синхронизация). Для маленьких батчей (batch size = 1) overhead может составлять значительную часть общего времени инференса.

Graph optimization уменьшает launch overhead двумя способами:

  • Сокращение числа ядер (constant folding, DCE, fusion).
  • Увеличение размера каждого ядра (fusion объединяет мелкие операции в одно крупное ядро).

Пример: в LLM с 32 слоями и 4 операциями на слой без оптимизации будет 128 запусков ядер. После fusion (объединение в 1 ядро на слой) — 32 запуска. После constant folding и DCE — ещё меньше.


7. Инструменты для graph optimization

ИнструментФреймворкОсобенности
XLA (Accelerated Linear Algebra)TensorFlow, JAXКомпилирует подграфы в оптимизированный код, использует HLO IR.
TVMApache TVMПолный стек: от импорта моделей до генерации кода для разных устройств.
torch.compilePyTorch 2.xИспользует TorchDynamo для захвата графа и TorchInductor для генерации кода.
MLIRLLVMМногоуровневое IR, позволяет применять оптимизации на разных уровнях абстракции.
ONNX RuntimeONNXОптимизирует граф ONNX, включая constant folding и DCE.

8. Пример кода: torch.compile с constant folding

import torch
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 5)
        self.constant = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

    def forward(self, x):
        # constant folding: self.constant * 2.0 можно вычислить заранее
        scale = self.constant * 2.0
        out = self.linear(x)
        out = out + scale
        return out

model = SimpleModel()
x = torch.randn(1, 10)

# Без компиляции
out1 = model(x)

# С компиляцией (torch.compile выполнит constant folding, DCE и fusion)
compiled_model = torch.compile(model, mode="reduce-overhead")
out2 = compiled_model(x)

print(torch.allclose(out1, out2))  # True

При компиляции torch.compile:

  • Выражение self.constant * 2.0 будет вычислено один раз и заморожено как константа.
  • Если scale используется только в одном месте, лишние операции удаляются.
  • linear и add могут быть слиты в одно ядро.

9. Практические соображения и trade-offs

  • Время компиляции Оптимизации графа требуют дополнительного времени на этапе компиляции (compile time). Для production-систем часто используют кэширование откомпилированных графов.
  • Динамические формы Если размеры входных тензоров меняются, некоторые оптимизации (shape specialization) становятся неприменимы. Приходится использовать динамические shape или компилировать несколько вариантов.
  • Совместимость Не все операции поддерживаются компиляторами. Например, сложные control flow (if, loops) могут привести к падению в eager mode.
  • Отладка Оптимизированный граф сложнее отлаживать, так как исходные имена операций теряются.

Пет-проект для закрепления

Задача Реализовать простой компилятор вычислительного графа на Python, который выполняет constant folding и dead code elimination для арифметических выражений.

Инструменты Python, библиотека networkx для работы с графами.

Шаги:

  1. Определить класс Node с полями: op (тип операции), inputs (список ссылок на другие узлы), value (для констант), output (булево, является ли выходом графа).
  2. Построить граф для выражения: ((3 + 4) * x) + (5 * 0). Здесь 3+4 — константное сложение, 5*0 — умножение на ноль.
  3. Реализовать constant_folding(graph): найти узлы, где все входы — константы, вычислить, заменить на константу.
  4. Реализовать dead_code_elimination(graph): пометить живые узлы от выходов, удалить непомеченные.
  5. Вывести граф до и после оптимизаций, показать количество узлов.
  6. Измерить время выполнения на случайных данных (сравнить с неоптимизированным графом).

Ожидаемый результат

  • После constant folding: 7 * x + 0.
  • После DCE: 7 * x (умножение на 0 удалено, так как результат 0 не используется).
  • Ускорение выполнения за счёт уменьшения числа операций.

Связь с другими вопросами

ВопросТема
320Как работают LLM компиляторы (torch.compile, XLA)?
322Что такое operator fusion и как он ускоряет инференс?
323Как quantization взаимодействует с компиляцией графа?
324Какие техники memory optimization применяются в LLM инференсе?
325Как профилировать и выявлять узкие места в графе модели?

Навигация