English translation is not available yet. Showing Russian content.

Как работает XLA (Accelerated Linear Algebra) для LLM на TPU?

Краткий тезис

XLA (Accelerated Linear Algebra) — это JIT-компилятор (Just-In-Time), который преобразует graph|вычислительный граф модели в высокооптимизированный машинный код для TPU (Tensor Processing Unit). Для LLM XLA даёт 2–3-кратное ускорение по сравнению с eager-режимом PyTorch за счёт слияния операций (fusion), статической типизации тензоров (static shapes) и эффективного планирования памяти. Без XLA современные LLM на TPU работали бы значительно медленнее из-за накладных расходов на запуск отдельных ядер (overhead|launch overhead|launch overhead|launch overhead|launch kernel overhead|launch overhead|launch overhead).


1. Что такое XLA и зачем он нужен для LLM

XLA — это компилятор для линейной алгебры, разработанный Google. Он входит в состав TensorFlow и JAX, а также доступен через torch.compile в PyTorch. Основная цель — автоматически оптимизировать выполнение операций, особенно на специализированном оборудовании (TPU, GPU).

Для LLM, которые состоят из тысяч операций (матричные умножения, softmax, layer norm), XLA решает две ключевые проблемы:

  • Накладные расходы на запуск ядер (kernel launch overhead): в eager-режиме каждая операция запускается отдельно, что приводит к задержкам.
  • Неоптимальное использование памяти: XLA может перепланировать размещение тензоров, уменьшая пиковое потребление памяти и ускоряя обмен с HBM (High Bandwidth Memory).

2. Основные этапы компиляции XLA

XLA работает в несколько этапов:

  1. Построение HLO (High Level Operations) — промежуточное представление графа, похожее на IR компилятора. HLO описывает все операции модели в виде направленного ациклического графа (DAG).
  2. Оптимизации HLO — применяются проходы:
    • Fusion — объединение нескольких операций в одно ядро.
    • Constant folding — вычисление константных выражений на этапе компиляции.
    • Dead code elimination — удаление неиспользуемых вычислений.
    • Memory planning — анализ времени жизни тензоров и переиспользование памяти.
  3. Генерация кода для TPUHLO транслируется в инструкции для TPU-ядер (TensorCores) и Matrix Units (MXU). Для TPU v4 это включает:
    • Разбиение графа на systolic array вычисления.
    • Планирование загрузки весов из HBM в VMEM (Vector Memory).
    • Генерация кода для scalar unit и vector unit.

3. Статические формы (static shapes) — ключевое требование

XLA требует, чтобы все тензоры имели фиксированные размерности на этапе компиляции. Это позволяет:

  • Заранее выделить память под все промежуточные результаты.
  • Оптимизировать layout данных (например, транспонировать матрицы для эффективного доступа).
  • Сгенерировать код без проверок размеров во время выполнения.

Для LLM это означает, что batch size, sequence length и hidden size должны быть известны до компиляции. Если форма меняется (например, разная длина последовательности в инференсе), XLA приходится перекомпилировать граф — это называется recompilation overhead. Чтобы избежать этого, используют padding до максимальной длины или dynamic batching.


4. Операции слияния (fusion) — уменьшение числа kernel launches

Одна из главных оптимизаций XLAkernel fusion. Вместо того чтобы запускать отдельные ядра для каждой операции (например, matmul, add, relu), XLA объединяет их в одно ядро.

Пример без fusion (eager):

z = matmul(x, W)   # kernel 1
z = z + b          # kernel 2
z = relu(z)        # kernel 3

Каждый kernel запускается отдельно, данные передаются через HBM.

С fusion (XLA):

fused_kernel(x, W, b):
    z = matmul(x, W)
    z = z + b
    z = relu(z)
    return z

Одно ядро читает x, W, b из HBM один раз, выполняет все операции на MXU/vector unit и записывает результат.

Для LLM с тысячами операций fusion сокращает число запусков ядер в десятки раз, что даёт основной прирост скорости.


5. Планирование памяти (memory planning)

XLA анализирует время жизни каждого тензора в графе и переиспользует память между тензорами, которые не пересекаются по времени. Это особенно важно для LLM, где промежуточные активации занимают гигабайты.

Пример: в трансформере после вычисления attention scores (QK^T) и softmax, исходные Q и K больше не нужны. XLA может выделить память для scores поверх Q или K, экономя HBM.

На TPU память делится на HBM (высокая ёмкость, ~32 ГБ на чип) и VMEM (быстрая, но маленькая, ~16 МБ). XLA старается размещать часто используемые данные (веса, кэш KV) в VMEM, а остальное — в HBM с DMA-передачами.


6. Использование TPU Matrix Units (MXU) для матричных умножений

Сердце TPU — Matrix Multiply Unit (MXU), работающий как systolic array (систолический массив). XLA генерирует код, который:

  • Разбивает матричные умножения на блоки (tiles), помещающиеся в MXU (например, 128×128).
  • Загружает блоки из VMEM в регистры MXU.
  • Выполняет умножение за один такт (для TPU v4 — до 128×128×16 операций за цикл).
  • Суммирует частичные результаты.

Для LLM, где доминируют операции matmul (внимание, FFN), XLA максимизирует загрузку MXU, избегая простоев.


7. Сравнение XLA с eager-режимом PyTorch

ХарактеристикаEager PyTorchXLA (на TPU)
Запуск операцийКаждая операция — отдельный kernel launchВесь граф компилируется в одно или несколько ядер
Управление памятьюВыделение/освобождение на летуПланирование заранее, переиспользование
Поддержка динамических формДа (любые размеры)Требует статических форм (или рекомпиляцию)
ОтладкаЛегко (пошаговое выполнение)Сложно (граф оптимизирован, трудно сопоставить с исходным кодом)
ПроизводительностьБазовая2–3x быстрее на TPU (до 10x на некоторых операциях)
ГибкостьВысокая (условные операторы, циклы)Ограничена (граф должен быть статическим)

8. Ограничения XLA

  • Dynamic shapes — если форма тензора меняется между вызовами, XLA перекомпилирует граф, что может занимать секунды. Для LLM инференса это критично: разные длины последовательностей вызывают рекомпиляцию. Решение — padding до максимальной длины или использование XLA:GPU с поддержкой динамики (но на TPU это ограничение жёстче).
  • Compilation overhead — первая компиляция может занимать минуты. Для продакшена используют persistent compilation (сохранение скомпилированного кода).
  • Сложность отладки — оптимизированный код трудно сопоставить с исходным. Инструменты: XLA HLO dump, TensorBoard для визуализации графа.
  • Не все операции поддерживаются — редкие операции (например, custom CUDA kernels) могут не компилироваться под TPU.

9. XLA в экосистеме: JAX, TensorFlow, PyTorch

  • JAX — наиболее тесно интегрирован с XLA. Каждая функция, переданная в jax.jit, компилируется XLA. JAX автоматически выводит статические формы и поддерживает pjit для распределённого выполнения на нескольких TPU.
  • TensorFlow — XLA используется через tf.function (autograph) и tf.xla.experimental.compile. В TF 2.x XLA включён по умолчанию для TPU.
  • PyTorchtorch.compile с бэкендом "xla" (через torch_xla). Позволяет компилировать отдельные участки модели. Для LLM часто используют torch_xla с xla_model для распределённого обучения на TPU.

10. Пример: компиляция LLM inference графа на TPU

Рассмотрим упрощённый инференс одного слоя трансформера (attention + FFN) на TPU через JAX:

import jax
import jax.numpy as jnp

def transformer_layer(x, W_q, W_k, W_v, W_o, W_ff1, W_ff2):
    # Attention
    q = jnp.dot(x, W_q)
    k = jnp.dot(x, W_k)
    v = jnp.dot(x, W_v)
    attn_scores = jnp.dot(q, k.T) / jnp.sqrt(k.shape[-1])
    attn_weights = jax.nn.softmax(attn_scores, axis=-1)
    attn_out = jnp.dot(attn_weights, v)
    attn_out = jnp.dot(attn_out, W_o)
    # FFN
    ff = jnp.dot(attn_out, W_ff1)
    ff = jax.nn.relu(ff)
    ff = jnp.dot(ff, W_ff2)
    return ff

# Компиляция с XLA
compiled_layer = jax.jit(transformer_layer, static_argnums=(0,))  # x — статическая форма
# Вызов
output = compiled_layer(x, W_q, W_k, W_v, W_o, W_ff1, W_ff2)

XLA сольёт все matmul, softmax и relu в несколько оптимизированных ядер, перепланирует память и загрузит веса в VMEM. На TPU v4 такой слой выполняется за ~10 мкс (против ~30 мкс в eager-режиме).


11. Пет-проект для закрепления

Задача Сравнить производительность eager-режима PyTorch и XLA (через torch.compile с бэкендом "xla") на небольшом трансформере (2 слоя, 4 головы, d_model=128) на TPU (можно использовать Google Colab с TPU).

Инструменты PyTorch, torch_xla, timeit, torch.utils.benchmark.

Шаги:

  1. Установить torch_xla и настроить TPU runtime.
  2. Написать класс SimpleTransformer с attention и FFN.
  3. Запустить инференс в eager-режиме (без компиляции) — замерить latency для batch_size=1, seq_len=128.
  4. Обернуть forward в torch.compile(backend="xla") и запустить снова (первый запуск — компиляция, второй — замер).
  5. Сравнить время выполнения, пиковое потребление памяти (через torch.cuda.max_memory_allocated для TPU — xla_device).
  6. Построить таблицу результатов.

Ожидаемый результат XLA-версия должна быть в 2–3 раза быстрее, а пиковое потребление памяти — на 20–30% ниже.


12. Связь с другими вопросами

ВопросТема
318Архитектура TPU (ядра, MXU, HBM)
320Pallas — низкоуровневые ядра для TPU
317JAX vs PyTorch для TPU
316Модельный параллелизм на TPU (sharding)
321Оптимизация памяти LLM на TPU (flash attention)
315Сравнение GPU и TPU для LLM

13. Навигация


Навигация