Aivaro
  • Оглавление
  • Вопросы
  • Практика
  • Вики
  • Материалы сообщества
  • Тесты
  • Поиск
✈Telegram @ai_varo
RUEN中文
…
Оглавление/Вопросы/#958

Как объединить несколько LoRA адаптеров для разных доменов (LoRA Hub)? Проблема конфликта весов и пути её решения.

Краткий тезис

При дообучении большой языковой модели под несколько доменов (например, юриспруденция и медицина) часто обучают отдельные LoRA-адаптеры для каждой задачи. Возникает потребность объединить их в один мульти-доменный адаптер без повторного обучения. Простейшее сложение матриц обновлений приводит к конфликту весов, когда одинаковые веса получают противоположные обновления от разных доменов. Решения включают взвешенное суммирование, поэлементные операции (product, average), Ties-Merging и использование обучаемой gating-сети (LoRA Hub), которая динамически смешивает адаптеры на уровне токенов или слоёв.

2. Конфликт: домены могут требовать противоположных обновлений

Представим, что один домен требует увеличить значение определённого веса (w_{ij}) (например, для активации нейрона, отвечающего за юридические термины), а второй – уменьшить это же значение. Тогда ([Delta](/wiki/Delta) w_{ij}^{(1)} > 0), ([Delta](/wiki/Delta) w_{ij}^{(2)} < 0), и сумма может дать малое или нулевое изменение, нейтрализуя вклад обоих адаптеров. В результате модель не будет эффективно обрабатывать ни один из доменов.

Формально:
Пусть ( [Delta](/wiki/Delta) W_1 = B_1A_1) и ([Delta](/wiki/Delta) W_2 = B_2A_2). Для определённой позиции ((i,j)):
[ [Delta](/wiki/Delta) w_{ij}^{(1)} + [Delta](/wiki/Delta) w_{ij}^{(2)} \approx 0 \quad [text](/wiki/text){(противоположные знаки)}. ]

Это проявляется особенно остро, когда ранги адаптеров малы ((r \ll d,k)), и каждый элемент ([Delta](/wiki/Delta) W) влияет на множество связей. Конфликт неизбежен при совместном использовании независимо обученных адаптеров без учёта их взаимодействия.


3. Решения: взвешенная сумма, product, Ties-Merging

3.1. Взвешенное суммирование (Weighted Sum)

Вместо равного сложения вводятся веса (\lambda_1, \lambda_2) (часто (\sum \lambda_i = 1)):
[ W_{[text](/wiki/text){merged}} = W + \lambda_1 [Delta](/wiki/Delta) W_1 + \lambda_2 [Delta](/wiki/Delta) W_2. ]

Веса можно задать эмпирически (например, (\lambda_1=0.7), (\lambda_2=0.3)) или настраивать на небольшом валидационном наборе. Метод не решает проблему противоположных знаков, но позволяет ослабить вклад конфликтующего адаптера.

3.2. Поэлементное произведение (Product) или среднее (Average)

Вместо суммы можно использовать поэлементное произведение:
[ W_{[text](/wiki/text){merged}} = W + [Delta](/wiki/Delta) W_1 \odot [Delta](/wiki/Delta) W_2. ]
Это экстремально: если хотя бы один адаптер имеет малые значения в позиции, произведение становится ещё меньше. Обычно применяется лишь в сочетании с нелинейностями (например, после softmax). Более практично – поэлементное среднее:
[ W_{[text](/wiki/text){merged}} = W + \frac{[Delta](/wiki/Delta) W_1 + [Delta](/wiki/Delta) W_2}{2}. ]
Среднее уменьшает амплитуду обновлений, но не устраняет конфликт знаков.

3.3. Ties-Merging (Yadav et al., 2023)

Известный метод из области слияния моделей, адаптированный для LoRA. Работает в три этапа:

  1. Trim – обнулить наименьшие по абсолютной величине компоненты каждого ([Delta](/wiki/Delta) W_i) (оставить топ-k%).
  2. Sign agreement – определить для каждой позиции знак большинства: если у двух адаптеров совпадает, оставляем его, иначе – обнуляем (disjoint sign → 0).
  3. Merge – усреднить оставшиеся ненулевые значения (с учётом знака большинства).

Этот подход явно решает конфликт: когда знаки противоположны, вклад адаптера в данной позиции отбрасывается. Результат – модель, которая сохраняет специализацию в непротиворечивых областях.

Реализация (скелет):

def ties_merge(delta_Ws, top_k_ratio=0.3):
    n = len(delta_Ws)
    mask = torch.stack([torch.abs(d) for d in delta_Ws])    # (n, d, k)
    mask = (mask >= torch.quantile(mask, 1-top_k_ratio, dim=0)).float()
    
    sign_sum = torch.sum(torch.stack([torch.sign(d)*m for d,m in zip(delta_Ws, mask)]), dim=0)
    majority_sign = torch.sign(sign_sum)
    agreement = (torch.abs(sign_sum) >= 1)  # хотя бы один адаптер совпадает с большинством
    
    merged = torch.zeros_like(delta_Ws[0])
    for d in delta_Ws:
        merged += d * (torch.sign(d) == majority_sign).float() * agreement
    return merged / (torch.abs(delta_Ws).sum(dim=0).clamp(min=1e-8))

4. LoRAHub: обучение gating-сети

LoRA Hub (или LoRA Hub) – продвинутый подход, предложенный в работах по модульным нейронным сетям. Основная идея: не фиксировать веса смеси, а научить небольшую сеть (gating network) динамически выбирать, какой из адаптеров активировать для каждого входного токена (или слоя).

4.1. Архитектура LoRA Hub

Пусть имеется K адаптеров с матрицами ([Delta](/wiki/Delta) W_k = B_k A_k) (или отдельными A и B). Gating-сеть (G) принимает на вход скрытое представление (h) (из предыдущего слоя) и выводит вектор весов ( [alpha](/wiki/alpha) = [text](/wiki/text){softmax}(G(h)) \in \mathbb{R}^K). Тогда обновление для данного токена:
[ [Delta](/wiki/Delta) W_{[text](/wiki/text){hub}} = \sum_{k=1}^K \alpha_k \cdot [Delta](/wiki/Delta) W_k. ]

Gating-сеть обучается на смешанном датасете из всех доменов с целью минимизации языковой потери. При этом основные предобученные веса (W) и параметры адаптеров заморожены – обучаются только веса гейта (обычно один линейный слой + softmax, размерность мала: (d_{[text](/wiki/text){hidden}} \times K)).

4.2. Преимущества

  • Полностью устраняет конфликт весов на уровне отдельных ячеек за счёт динамического выбора, а не фиксированной суммы.
  • Позволяет одному токену использовать комбинацию сразу нескольких доменов (например, юридический термин с медицинским контекстом).
  • Легко дообучить под новый домен, добавив новый адаптер и расширив гейт.

4.3. Практическая реализация

Для каждого Linear слоя, к которому применяется LoRA, добавляется gating-модуль. На этапе инференса вычисляется ([alpha](/wiki/alpha)) и применяется взвешенная сумма обновлений. Так как вычислять (B_k A_k) для каждого токена накладно, можно предварительно вычислить ([Delta](/wiki/Delta) W_k) и хранить их, но это увеличивает память. Альтернатива – вычислять (B_k(A_k x)) или (A_k(B_k^T x)) в зависимости от порядка (weight tying). На практике часто используют два пути: для каждого адаптера считают (B_k (A_k x)), затем взвешивают.

class LoRAHubLayer(nn.Module):
    def __init__(self, base_linear, lora_list, hidden_size, num_loras):
        super().__init__()
        self.base = base_linear
        self.lora_list = nn.ModuleList(lora_list)  # каждый LoRA с B, A
        self.gate = nn.Linear(hidden_size, num_loras)
    
    def forward(self, x):
        base_out = self.base(x)
        alpha = torch.softmax(self.gate(x), dim=-1)  # (batch, seq_len, num_loras)
        lora_outs = torch.stack([lora(x) for lora in self.lora_list], dim=-1)  # (b, seq, out_dim, num_loras)
        lora_out = torch.sum(lora_outs * alpha.unsqueeze(-2), dim=-1)
        return base_out + lora_out

Пет-проект для закрепления

Задача: Объединить два LoRA-адаптера, дообученных на юридических и медицинских текстах, в один адаптер, используя метод Ties-Merging, затем сравнить качество на датасете смешанного домена по сравнению с наивным сложением.

Инструменты:

  • Python, PyTorch, HuggingFace Transformers, PEFT библиотека.
  • Датасеты: примеры юридических и медицинских инструкций (например, legal-qa, medical-meadow).
  • Предобученная LLM: Llama 2 7B (можно уменьшенная версия).

Шаги:

  1. Загрузить базовую модель и два обученных LoRA-адаптера (скачать готовые или обучить самостоятельно на 5000 примеров каждого домена).
  2. Извлечь матрицы ([Delta](/wiki/Delta) W) для каждого целевого слоя из адаптеров (через adapter.get_base_layer().weight и матрицы A, B).
  3. Реализовать функцию Ties-Merging (trim 30%, majority sign).
  4. Слить адаптеры и сохранить новую LoRA-конфигурацию (параметры (B_{[text](/wiki/text){merged}}, A_{[text](/wiki/text){merged}}) – если требуется низкоранговое представление, придётся сделать аппроксимацию SVD).
  5. Оценить perplexity и accuracy на задаче Q&A для трёх вариантов: только юридический адаптер, только медицинский, наивная сумма, Ties-Merged.
  6. Построить таблицу результатов.

Ожидаемый результат:

  • Ties-Merging покажет меньший конфликт: перплексия на юридических вопросах останется близка к специализированному адаптеру, а на медицинских – не упадёт катастрофически.
  • Наивная сумма даст среднюю перплексию, но может проиграть на обоих доменах из-за взаимной нейтрализации.

Связь с другими вопросами

ВопросТема
40Объединение моделей: слияние весов, FedAvg, model soup

Навигация

  • Предыдущий: 957
  • Следующий: 959
  • Индекс: 00. Индекс разборов zed LoRA) и зачем он нужен?|959]]
  • Индекс: 00. Индекс разборов