Как работает Mixture of Experts (MoE) внутри LLM (спарсинг активации)?

Краткий тезис

Mixture of Experts (MoE) — это архитектурный приём, при котором в Transformer вместо одного плотного Feed-Forward Network (FFN) используется несколько параллельных FFN («экспертов»). Router (линейный слой с softmax) для каждого токена предсказывает веса экспертов, активируются только top-k (обычно k=2) экспертов, а выход формируется как взвешенная сумма их выходов. Это позволяет иметь огромное общее число параметров (например, 47B у Mixtral 8x7B), но при инференсе выполнять вычисления, эквивалентные модели с 13B параметров, за счёт спарсинг активации — каждый токен использует лишь малую долю параметров.

1. Термин: Mixture of Experts (MoE)

MoE — это метод ансамблирования нейронных сетей, где несколько «экспертов» (подсетей) специализируются на разных подзадачах, а gating network (router) динамически выбирает, каких экспертов активировать для каждого входного примера. В контексте LLM экспертами обычно являются Feed-Forward Network (FFN) слои, а gating — это обучаемый линейный слой с softmax.

История Идея MoE восходит к работе Jacobs et al. (1991) «Adaptive Mixtures of Local Experts», но в NLP её популяризовали Shazeer et al. (2017) в статье «Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer». Позже MoE была применена в Switch Transformer (Fedus et al., 2021) и Mixtral 8x7B (Mistral AI, 2023).

Зачем это нужно Масштабирование LLM традиционным способом (увеличение числа параметров) ведёт к квадратичному росту вычислительных затрат. MoE позволяет наращивать число параметров без пропорционального увеличения FLOPs, что критично для обучения и инференса больших моделей.

2. Проблема: стандартный FFN в Transformer

В классическом Transformer каждый блок содержит self-attention и FFN. FFN состоит из двух линейных слоёв с нелинейностью (обычно ReLU или SiLU):

FFN(x) = W2 * activation(W1 * x + b1) + b2

Размерность скрытого слоя (hidden_dim) обычно в 4 раза больше размерности модели (d_model). Например, для LLaMA-7B: d_model=4096, hidden_dim=11008. Это означает, что для каждого токена активируются все 11008 нейронов — плотная активация. Если мы хотим увеличить ёмкость модели, приходится увеличивать hidden_dim, что линейно увеличивает FLOPs.

Проблема Не все токены требуют одинаковой вычислительной мощности. Некоторые токены (например, артикли, предлоги) могут быть обработаны с меньшими затратами, чем сложные семантические токены. MoE решает эту проблему, предоставляя набор специализированных экспертов и динамически выбирая только нужные.

3. Архитектура MoE

Типичный MoE-слой в LLM выглядит так:

Вход: x (вектор токена)
1. Router: r = softmax(W_r * x)  # W_r — матрица router'а размером (n_experts, d_model)
2. Top-k selection: выбираем k экспертов с наибольшими весами
3. Для каждого выбранного эксперта e:
   y_e = FFN_e(x)
4. Выход: y = sum(weight_e * y_e for e in selected)

Компоненты

  • Эксперты (Experts): Обычно это идентичные по архитектуре FFN, но с разными весами. Количество экспертов (n_experts) варьируется от 8 до 2048.
  • Router (Gating Network): Линейный слой без bias, выходной размер равен n_experts. После softmax получаем распределение вероятностей.
  • Top-k selection Выбираются k экспертов с наибольшими вероятностями. Остальным экспертам присваивается вес 0. k обычно равно 1 (Switch Transformer) или 2 (Mixtral).
  • Взвешенная сумма Выходы выбранных экспертов умножаются на соответствующие веса router'а и суммируются.

Пример кода на PyTorch (упрощённый):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MoELayer(nn.Module):
    def __init__(self, d_model, hidden_dim, n_experts, k=2):
        super().__init__()
        self.n_experts = n_experts
        self.k = k
        # Эксперты: список FFN
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, d_model)
            ) for _ in range(n_experts)
        ])
        # Router
        self.router = nn.Linear(d_model, n_experts, bias=False)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        batch, seq, d = x.shape
        # Router логиты
        router_logits = self.router(x)  # (batch, seq, n_experts)
        router_weights = F.softmax(router_logits, dim=-1)

        # Top-k выбор
        topk_weights, topk_indices = torch.topk(router_weights, self.k, dim=-1)
        # topk_weights: (batch, seq, k), topk_indices: (batch, seq, k)

        # Нормализация весов (сумма по k = 1)
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

        # Инициализация выхода
        output = torch.zeros_like(x)

        # Для каждого эксперта
        for expert_idx in range(self.n_experts):
            # Маска: где этот эксперт в top-k
            mask = (topk_indices == expert_idx).any(dim=-1)  # (batch, seq)
            if not mask.any():
                continue
            # Входы для этого эксперта
            expert_input = x[mask]
            # Выход эксперта
            expert_output = self.experts[expert_idx](expert_input)
            # Веса для этого эксперта
            # Находим позицию эксперта в top-k для каждого токена
            # Упрощённо: берём вес из topk_weights по индексу эксперта
            # (в реальной реализации используется gather)
            weight_mask = (topk_indices == expert_idx).float()  # (batch, seq, k)
            weights = (topk_weights * weight_mask).sum(dim=-1)  # (batch, seq)
            output[mask] += expert_output * weights[mask].unsqueeze(-1)

        return output, router_logits

4. Спарсинг активации (Sparsity Activation)

Спарсинг активации означает, что для каждого токена активируется только малая часть параметров модели. В MoE это достигается за счёт top-k выбора экспертов. Если у нас 8 экспертов и k=2, то активируется только 25% экспертов. Остальные 6 экспертов не участвуют в вычислениях для данного токена.

Ключевое следствие Общее число параметров модели (сумма всех экспертов + router) может быть огромным, но FLOPs (количество операций с плавающей точкой) на один токен определяется только k экспертами. Например, Mixtral 8x7B имеет 47B параметров (8 экспертов по 7B + общие слои), но при инференсе активируется только 2 эксперта, что даёт FLOPs, эквивалентные модели с 13B параметров.

Сравнение плотной и разреженной модели

ХарактеристикаПлотная модель (LLaMA-13B)MoE модель (Mixtral 8x7B)
Параметры13B47B
FLOPs на токен~13B~13B (2 эксперта из 8)
Память (веса)26 GB (fp16)94 GB (fp16)
Память (активации)~2 GB~2 GB (только 2 эксперта)
Скорость инференса1x~0.8-0.9x (из-за overhead router)

Важно Хотя FLOPs одинаковы, реальная скорость может быть ниже из-за дополнительных операций router'а и необходимости загружать все веса экспертов в память (даже неактивные). Однако при распределённом инференсе (эксперты на разных GPU) можно добиться высокой эффективности.

5. Router и Load Balancing

Проблема Router может научиться всегда выбирать одних и тех же экспертов (например, из-за неравномерности данных). Это приводит к тому, что некоторые эксперты перегружены, а другие не обучаются.

Решение Добавляют auxiliary loss (вспомогательную функцию потерь), которая штрафует за дисбаланс. Типичный подход — load balancing loss из Switch Transformer:

L_aux = α * n_experts * sum(f_i * P_i)

где:

  • f_i — доля токенов, назначенных эксперту i (фактическая загрузка)
  • P_i — средняя вероятность, которую router присваивает эксперту i (ожидаемая загрузка)
  • α — коэффициент (обычно 0.01)

Эта loss минимизируется вместе с основной loss (cross-entropy). Она стимулирует router распределять токены равномерно.

Другие техники балансировки

  • Expert Choice routing (Zhou et al., 2022): не токены выбирают экспертов, а эксперты выбирают токены.
  • Auxiliary loss с дифференцируемым top-k (например, soft top-k).
  • Добавление шума в router при обучении для улучшения исследования.

6. Примеры MoE моделей

МодельЧисло экспертовkПараметры (всего)FLOPs (эквивалент)Особенности
Switch Transformer (Base)6417B1.5BПервая крупная MoE в NLP
GLaM (Google, 2021)6421.2T64BОбучена на 1.6T токенов
Mixtral 8x7B (Mistral, 2023)8247B13BОткрытая, превосходит LLaMA-2 70B
DeepSeek-MoE (2024)64 (fine-grained)6145B21BShared expert + routed experts
Qwen2.5-MoE8214B4BОптимизирована для инференса

Mixtral 8x7B — наиболее известная открытая MoE модель. Она состоит из 8 экспертов, каждый размером 7B (но фактически 7B — это параметры только FFN, а attention и embedding общие). Router выбирает 2 эксперта на токен. Модель показывает качество, сравнимое с LLaMA-2 70B, при значительно меньших вычислительных затратах.

7. Преимущества MoE

  1. Масштабирование без пропорционального роста compute Можно увеличивать число экспертов, не увеличивая FLOPs на токен.
  2. Специализация экспертов Эксперты могут выучить разные паттерны (синтаксис, семантика, факты), что улучшает качество.
  3. Эффективное обучение При том же бюджете FLOPs MoE модели могут быть обучены на большем количестве токенов, что даёт лучшие результаты (см. Chinchilla scaling laws).
  4. Гибкость Можно легко добавлять новых экспертов после обучения (continual learning).

8. Недостатки и вызовы

  1. Память Все веса экспертов должны храниться в памяти (даже неактивные). Для Mixtral 8x7B нужно 94 GB в fp16, что требует нескольких GPU или специальных техник (offloading).
  2. Load balancing Необходимость дополнительной loss и риск коллапса router'а.
  3. Overhead коммуникации При распределённом обучении/инференсе эксперты могут быть размещены на разных устройствах, что требует передачи активаций между ними (all-to-all communication).
  4. Сложность обучения MoE модели более чувствительны к гиперпараметрам (learning rate, batch size), требуют стабилизации (например, gradient clipping).
  5. Инференс latency Router добавляет задержку, а при batch-обработке разные токены могут выбирать разных экспертов, что усложняет эффективную реализацию.

9. Связь с Agentic RAG

В контексте Agentic RAG MoE может быть использована не только внутри LLM, но и как архитектурный паттерн для выбора инструментов или стратегий. Например:

  • Router как агент Вместо выбора экспертов-FFN, router может выбирать между разными retrieval-стратегиями (sparse vs dense), разными базами знаний или разными LLM.
  • Multi-agent системы Каждый агент может быть «экспертом» (специализирован на определённом домене), а router (оркестратор) направляет запрос к нужному агенту.
  • Tool use MoE-подобный подход: для каждого запроса выбирается top-k инструментов из множества доступных.

Таким образом, понимание MoE помогает проектировать эффективные агентные системы с динамическим выбором подзадач.

Пет-проект для закрепления

Задача Реализовать простую MoE-модель на PyTorch и сравнить её с плотной моделью на задаче классификации текстов (например, SST-2).

Инструменты PyTorch, Hugging Face Datasets, Weights & Biases (опционально).

Шаги:

  1. Создать датасет Загрузить SST-2 (sentiment analysis) из datasets. Разделить на train/validation.
  2. Определить плотную модель Transformer с одним FFN (d_model=128, hidden_dim=512, 4 слоя).
  3. Определить MoE модель Те же 4 слоя, но FFN заменён на MoE с 8 экспертами, k=2. Размер каждого эксперта: hidden_dim=512.
  4. Обучить обе модели Использовать AdamW, learning rate 1e-4, batch size 32, 5 эпох. Для MoE добавить load balancing loss (α=0.01).
  5. Сравнить
    • Число параметров (model.summary()).
    • FLOPs на один токен (можно оценить через thop.profile).
    • Accuracy на validation.
    • Время инференса на CPU/GPU.
  6. Визуализировать загрузку экспертов Построить гистограмму, сколько токенов попало к каждому эксперту.

Ожидаемый результат

  • MoE модель будет иметь в 4-5 раз больше параметров, но FLOPs будут примерно одинаковы.
  • Accuracy может быть немного выше у MoE (за счёт специализации).
  • Гистограмма загрузки экспертов должна быть относительно равномерной (если load balancing работает).

Связь с другими вопросами

ВопросТема
679Архитектура Agentic RAG: обзор компонентов
681Routing в RAG: как выбирать источники знаний
682Tool use: интеграция внешних инструментов
683Multi-agent системы: координация агентов
684Self-RAG: рефлексия и самокоррекция
685Corrective RAG: исправление ошибок retrieval

Навигация