mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
feat(mixtral): add MixtralDecoder, enhance MoE and Mixtral model docs, add unit tests
- Implement new core module: MixtralDecoder (llm/core/mixtral_decoder.py) with full Russian scientific docstrings, formal math, and usage examples - Improve MoE: add Russian docstrings for class, __init__, forward; validate top_k_experts; explain theory and components - Refactor Mixtral model: switch stack to MixtralDecoder, add comprehensive documentation for class, constructor and forward, clarify config usage and architecture - Add thorough unit tests: * tests/core/test_mixtral_decoder.py: checks shapes, errors, mask, dropout, grads etc. * tests/core/test_moe.py: covers normal and edge-case logic, gradients, shape, params check - All code and tests in compliance with recent scientific and engineering standards.
This commit is contained in:
211
llm/src/llm/core/mixtral_decoder.py
Normal file
211
llm/src/llm/core/mixtral_decoder.py
Normal file
@@ -0,0 +1,211 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
from llm.core.group_query_attention import GroupedQueryAttention
|
||||||
|
from llm.core.moe import MoE
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
|
||||||
|
class MixtralDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
MixtralDecoder — декодерный блок для Mixtral/MoE-трансформеров (см. Mixtral 8x7B, Mistral v0.2 и др.).
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
MixtralDecoder реализует один модульный слой глубокой трансформерной архитектуры с Mixture-of-Experts (MoE) Feed-Forward Network и Grouped Query Attention (GQA).
|
||||||
|
Поддерживает разреженную активацию и масштабируемое количество экспертов, оптимально для больших LLM.
|
||||||
|
|
||||||
|
Архитектура блока:
|
||||||
|
------------------
|
||||||
|
- RMSNorm -> Grouped Query Attention (GQA)
|
||||||
|
- skip-connection
|
||||||
|
- RMSNorm -> MoE (SwiGLU-эксперты)
|
||||||
|
- skip-connection
|
||||||
|
|
||||||
|
Для входа `x` проходит:
|
||||||
|
1. norm1_out = RMSNorm(x)
|
||||||
|
2. attention, kv_caches = GQA(norm1_out, ...)
|
||||||
|
3. out = attention + x # residual connection
|
||||||
|
4. norm2_out = RMSNorm(out)
|
||||||
|
5. ffn_out = MoE(norm2_out)
|
||||||
|
6. return (ffn_out + out, kv_caches)
|
||||||
|
|
||||||
|
Теоретическая мотивация:
|
||||||
|
------------------------
|
||||||
|
- Использование MoE (см. https://arxiv.org/abs/1701.06538) позволяет кратно увеличивать capacity без роста затрат на ff-часть.
|
||||||
|
- Grouped Query Attention эффективно масштабирует self-attention для больших моделей (см. Mistral, Llama 2/3).
|
||||||
|
- RMSNorm (Root Mean Square LayerNorm) стабилизирует градиенты и память.
|
||||||
|
- Является строительным блоком для стека декодеров в Mixtral-моделях (см. Mixtral, Mistral, LLaMA).
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
----------------------
|
||||||
|
num_q_heads : int
|
||||||
|
Число query-голов в attention.
|
||||||
|
num_kv_heads : int
|
||||||
|
Число key-value голов (группировка ключей/values).
|
||||||
|
emb_size : int
|
||||||
|
Скрытый размер эмбеддинга.
|
||||||
|
head_size : int
|
||||||
|
Размерность одной головы (emb_size // num_q_heads).
|
||||||
|
max_seq_len : int
|
||||||
|
Максимальная поддерживаемая длина последовательности.
|
||||||
|
num_experts : int
|
||||||
|
Количество «экспертов» (MoE).
|
||||||
|
top_k_experts : int
|
||||||
|
Сколько одновременно экспертов активируется для одного токена.
|
||||||
|
window_size : int
|
||||||
|
Размер окна внимания (используется для efficient attention).
|
||||||
|
rope : RoPE
|
||||||
|
Реализация позиционного кодирования RoPE.
|
||||||
|
dropout : float
|
||||||
|
Вероятность Dropout для регуляризации.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> decoder = MixtralDecoder(... параметры ...)
|
||||||
|
>>> x = torch.randn(batch, seq, emb_size)
|
||||||
|
>>> out, cache = decoder(x, mask=None, use_cache=True)
|
||||||
|
>>> out.shape
|
||||||
|
|
||||||
|
Литература и ссылки:
|
||||||
|
--------------------
|
||||||
|
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
|
||||||
|
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
|
||||||
|
- Mistral paper: https://arxiv.org/abs/2310.06825
|
||||||
|
- GQA: https://arxiv.org/abs/2305.14236
|
||||||
|
- RMSNorm: https://arxiv.org/abs/1910.07467
|
||||||
|
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_q_heads: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
emb_size: int,
|
||||||
|
head_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
num_experts: int,
|
||||||
|
top_k_experts: int,
|
||||||
|
window_size: int,
|
||||||
|
rope: RoPE,
|
||||||
|
dropout: float = 0.1
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Конструктор декодерного блока MixtralDecoder.
|
||||||
|
|
||||||
|
Осуществляет инициализацию всех под-компонентов слоя: Attention (Grouped Query Attention), MoE (Mixture-of-Experts, SwiGLU)
|
||||||
|
и нормализации (RMSNorm). Позволяет гибко настраивать архитектуру под специфику задач и размеры LLM.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
num_q_heads : int
|
||||||
|
Количество голов внимания (queries) в механизме GroupedQueryAttention.
|
||||||
|
Чем больше — тем тоньше дискретизация внимания по подпространствам признаков.
|
||||||
|
num_kv_heads : int
|
||||||
|
Количество групп ключей/значений (key-value heads) для GQA.
|
||||||
|
Позволяет балансировать производительность и память.
|
||||||
|
emb_size : int
|
||||||
|
Размерность эмбеддингового пространства внутри слоя (hidden).
|
||||||
|
head_size : int
|
||||||
|
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимально поддерживаемая длина токенизированной последовательности.
|
||||||
|
num_experts : int
|
||||||
|
Количество экспертов в слое MoE (размер пула SwiGLU-экспертов).
|
||||||
|
top_k_experts : int
|
||||||
|
Сколько экспертов по роутингу активируется на 1 токен (разреженность — эффективная экономия вычислений).
|
||||||
|
window_size : int
|
||||||
|
Размер окна для attention (может использоваться для ограничения receptive field, как в Mistral).
|
||||||
|
rope : RoPE
|
||||||
|
Объект позиционного кодирования RoPE (Rotary Positional Embedding), необходим для архитектуры внимания.
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Вероятность зануляции выходных значений для регуляризации и борьбы с переобучением.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> decoder = MixtralDecoder(
|
||||||
|
... num_q_heads=8,
|
||||||
|
... num_kv_heads=2,
|
||||||
|
... emb_size=256,
|
||||||
|
... head_size=32,
|
||||||
|
... max_seq_len=1024,
|
||||||
|
... num_experts=4,
|
||||||
|
... top_k_experts=2,
|
||||||
|
... window_size=128,
|
||||||
|
... rope=rope_module,
|
||||||
|
... dropout=0.05
|
||||||
|
... )
|
||||||
|
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._heads = GroupedQueryAttention(
|
||||||
|
num_q_heads=num_q_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
window_size=window_size,
|
||||||
|
rope=rope,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self._ff = MoE(
|
||||||
|
emb_size=emb_size,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k_experts=top_k_experts,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self._norm1 = RMSNorm(emb_size)
|
||||||
|
self._norm2 = RMSNorm(emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через декодерный блок MixtralDecoder.
|
||||||
|
|
||||||
|
Данный метод реализует последовательную обработку входных скрытых состояний (x) через:
|
||||||
|
- нормализацию (RMSNorm),
|
||||||
|
- attention-модуль (Grouped Query Attention) с опциональным применением маски и кэша ключей/значений для ускорения инференса,
|
||||||
|
- остаточное сложение (residual connection),
|
||||||
|
- повторную нормализацию,
|
||||||
|
- feed-forward блок на основе Mixture-of-Experts (MoE),
|
||||||
|
- финальное остаточное сложение.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной скрытый тензор формы [batch_size, seq_len, emb_size] — результат эмбеддинга токенов либо предыдущего слоя.
|
||||||
|
mask : torch.Tensor, optional
|
||||||
|
(Необязательно) Маска внимания для ограничения области self-attention (например, для автоперемешивания или causal-LLM-моделей).
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Если True — сохраняет кэш ключей/значений attention для ускорения авторегрессии (инференса).
|
||||||
|
cache : list, optional
|
||||||
|
(Необязательно) Предварительно вычисленный кеш attention (для ускорения генерации длинного текста).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
Tuple[torch.Tensor, Any]:
|
||||||
|
- Первый элемент: скрытый тензор выхода слоя с той же формой, что вход (последовательный residual из attention и MoE-блока).
|
||||||
|
- Второй элемент: обновлённый кэш attention (если use_cache=True), иначе None.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> out, cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
|
||||||
|
>>> out.shape # [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Для autoregressive-генерации (GPT-like режимов) следует передавать mask и использовать use_cache=True.
|
||||||
|
- Реализация поддерживает произвольные батчи и длины последовательностей, в пределах max_seq_len слоя.
|
||||||
|
- Модуль MixtralDecoder обычно используется в виде стека (несколько подряд) внутри крупной LLM.
|
||||||
|
|
||||||
|
"""
|
||||||
|
norm1_out = self._norm1(x)
|
||||||
|
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||||
|
out = attention + x
|
||||||
|
|
||||||
|
norm2_out = self._norm2(out)
|
||||||
|
ffn_out = self._ff(norm2_out)
|
||||||
|
|
||||||
|
if use_cache is True:
|
||||||
|
return (ffn_out + out, kv_caches)
|
||||||
|
else:
|
||||||
|
return (ffn_out + out, None)
|
||||||
@@ -4,6 +4,58 @@ import torch.nn.functional as F
|
|||||||
from llm.core.swi_glu import SwiGLU
|
from llm.core.swi_glu import SwiGLU
|
||||||
|
|
||||||
class MoE(nn.Module):
|
class MoE(nn.Module):
|
||||||
|
"""
|
||||||
|
MoE (Mixture of Experts) — слой «смеси экспертов» для современных трансформерных архитектур с разреженной активацией.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
Класс реализует слой разреженного условного вычисления для увеличения capacity трансформеров без роста вычислительных затрат.
|
||||||
|
Для каждого токена из последовательности выбирается (с помощью роутера) наиболее подходящее подмножество экспертов (малых нейросетей).
|
||||||
|
Итоговый выход формируется как взвешенная сумма откликов экспертов, выбранных для данного токена.
|
||||||
|
|
||||||
|
Архитектурная схема:
|
||||||
|
---------------------
|
||||||
|
- Для каждого входного токена `x` роутер (обычно один Linear-слой) предсказывает skor, насколько каждый из `num_experts` релевантен.
|
||||||
|
- Для каждого токена выбираются top_k_experts с максимальными skor; только они обрабатывают этот токен.
|
||||||
|
- Каждый эксперт здесь представлен отдельным экземпляром блока `SwiGLU` (может быть любая небольшая feed-forward сеть).
|
||||||
|
- Выход каждого эксперта умножается на индивидуальный вес (softmax по skor) — агрегируется взвешенная сумма.
|
||||||
|
- Dropout применяется к итоговому выходу.
|
||||||
|
|
||||||
|
Математика (коротко):
|
||||||
|
---------------------
|
||||||
|
Пусть X ∈ R^{BxSxD} — вход,
|
||||||
|
E — число экспертов,
|
||||||
|
K — число активируемых экспертов на токен.
|
||||||
|
r(x) = softmax(W_r x) — роутинг-логиты, top-K берём индексы и веса.
|
||||||
|
Для каждого токена:
|
||||||
|
y_j = Expert_j(x)
|
||||||
|
y = sum_j(w_j * y_j), где j пробегает по выбранным экспертам
|
||||||
|
Output: Y ∈ R^{BxSxD}
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
----------------------
|
||||||
|
emb_size : int
|
||||||
|
Размерность входных/выходных векторов (обычно совпадает с embedding модели).
|
||||||
|
num_experts : int
|
||||||
|
Общее число экспертов внутри слоя MoE.
|
||||||
|
top_k_experts : int
|
||||||
|
Сколько экспертов активировать и агрегировать на каждом токене (обычно 2-8).
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Dropout к выходу агрегатора.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> moe = MoE(emb_size=512, num_experts=8, top_k_experts=2, dropout=0.1)
|
||||||
|
>>> x = torch.randn(4, 16, 512)
|
||||||
|
>>> y = moe(x)
|
||||||
|
>>> y.shape # torch.Size([4, 16, 512])
|
||||||
|
|
||||||
|
Литература:
|
||||||
|
-----------
|
||||||
|
- Shazeer, N. et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
|
||||||
|
- Fedus, W., Zoph, B., & Shazeer, N. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”, 2021. https://arxiv.org/abs/2101.03961
|
||||||
|
- Mistral/Mixtral: https://mistral.ai/news/mixtral-of-experts/
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
emb_size: int,
|
emb_size: int,
|
||||||
@@ -11,7 +63,51 @@ class MoE(nn.Module):
|
|||||||
top_k_experts: int,
|
top_k_experts: int,
|
||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Конструктор слоя MoE (Mixture of Experts).
|
||||||
|
|
||||||
|
Позволяет создать слой, состоящий из набора экспертов (например, отдельных небольших feedforward-нейросетей) и роутера,
|
||||||
|
который будет для каждого токена определять наиболее релевантных экспертов.
|
||||||
|
Часть экспертов (top_k_experts) активируется для каждого токена, остальные — пропускаются.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
emb_size : int
|
||||||
|
Размерность входных и выходных векторов (embedding size).
|
||||||
|
Определяет, над каким пространством признаков будет работать роутер и эксперты.
|
||||||
|
Например, если скрытый размер слоя трансформера 512, сюда нужно передать 512.
|
||||||
|
|
||||||
|
num_experts : int
|
||||||
|
Общее количество экспертов в слое MoE.
|
||||||
|
Чем больше экспертов — тем больше capacity у модели, но тем выше требования к RAM/VRAM при обучении.
|
||||||
|
Пример: 8, 16, 32, 64.
|
||||||
|
|
||||||
|
top_k_experts : int
|
||||||
|
Сколько экспертов одновременно будет обрабатывать каждый токен.
|
||||||
|
Обычно 2–8. Меньшее значение — выше разреженность, больше экономия вычислений.
|
||||||
|
|
||||||
|
dropout : float, по умолчанию 0.1
|
||||||
|
Вероятность зануления значений на выходе после агрегации откликов экспертов.
|
||||||
|
Используется для регуляризации (борьбы с переобучением).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> moe = MoE(emb_size=256, num_experts=8, top_k_experts=2, dropout=0.1)
|
||||||
|
>>> print(moe)
|
||||||
|
MoE( ... )
|
||||||
|
|
||||||
|
Теория:
|
||||||
|
-------
|
||||||
|
Слой строит:
|
||||||
|
- Линейный роутер (Linear(emb_size, num_experts)): выдает «важность» каждого эксперта для токена.
|
||||||
|
- Список из num_experts экспертов (в данной реализации — SwiGLU-блоки).
|
||||||
|
|
||||||
|
При каждом проходе для каждого токена выбираются top_k_experts наиболее релевантных экспертов,
|
||||||
|
их ответы агрегируются взвешенной суммой (softmax по роутерным логитам).
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if top_k_experts > num_experts:
|
||||||
|
raise ValueError(f"top_k_experts ({top_k_experts}) должен быть меньше или равен num_experts ({num_experts})!")
|
||||||
self._num_experts = num_experts
|
self._num_experts = num_experts
|
||||||
self._top_k_experts = top_k_experts
|
self._top_k_experts = top_k_experts
|
||||||
|
|
||||||
@@ -23,6 +119,47 @@ class MoE(nn.Module):
|
|||||||
self._dropout = nn.Dropout(dropout)
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через слой MoE.
|
||||||
|
|
||||||
|
Для входной последовательности скрытых состояний (обычно из предыдущего слоя трансформера)
|
||||||
|
данный метод динамически выбирает для каждого токена топ-k наиболее релевантных экспертов с помощью роутера,
|
||||||
|
пропускает соответствующие токены через выбранных экспертов и агрегирует их результаты.
|
||||||
|
|
||||||
|
Математически:
|
||||||
|
--------------
|
||||||
|
1. Для каждого токена вычисляются логиты маршрутизатора (роутера):
|
||||||
|
router_logits = Linear(x) ∈ ℝ^{batch, seq, num_experts}
|
||||||
|
2. Выбираются top_k экспертов (topk_indices) и соответствующие им softmax-веса (topk_weights).
|
||||||
|
3. Каждый эксперт обрабатывает только свой поднабор токенов.
|
||||||
|
4. Результат агрегируется — отклик эксперта умножается на вес, ответы суммируются для каждого токена.
|
||||||
|
5. На результат применяется dropout для регуляризации.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Трёхмерный входной тензор формы [batch_size, seq_length, emb_size],
|
||||||
|
где batch_size — размер батча, seq_length — длина последовательности, emb_size — размерность эмбеддинга.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
torch.Tensor :
|
||||||
|
Тензор той же формы [batch_size, seq_length, emb_size] — результат комбинирования выходов выбранных экспертов
|
||||||
|
с учетом softmax-весов маршрутизатора и dropout'а.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> y = moe(x)
|
||||||
|
>>> print(y.shape)
|
||||||
|
torch.Size([batch_size, seq_length, emb_size])
|
||||||
|
|
||||||
|
Примечание:
|
||||||
|
-----------
|
||||||
|
- Каждый токен чаще всего активирует только подмножество экспертов.
|
||||||
|
- Остальные эксперты вычислительно “спят”, что позволяет строить очень большие (по параметрам) модели с малым ростом затрат.
|
||||||
|
- Работа с распределением топ-к экспертов и агрегирование с весами реализовано автоматически.
|
||||||
|
|
||||||
|
"""
|
||||||
batch_size, seq_len, emb_size = x.shape
|
batch_size, seq_len, emb_size = x.shape
|
||||||
|
|
||||||
# 1. Пропускаем через роутер
|
# 1. Пропускаем через роутер
|
||||||
|
|||||||
@@ -7,64 +7,120 @@ from llm.core.base_model import BaseModel
|
|||||||
from llm.core.token_embeddings import TokenEmbeddings
|
from llm.core.token_embeddings import TokenEmbeddings
|
||||||
from llm.core.rope import RoPE
|
from llm.core.rope import RoPE
|
||||||
from llm.core.rms_norm import RMSNorm
|
from llm.core.rms_norm import RMSNorm
|
||||||
from llm.core.moe import MoE
|
from llm.core.mixtral_decoder import MixtralDecoder
|
||||||
from llm.core.group_query_attention import GroupedQueryAttention
|
|
||||||
|
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
|
||||||
def __init__(self,
|
|
||||||
num_q_heads: int,
|
|
||||||
num_kv_heads: int,
|
|
||||||
emb_size: int,
|
|
||||||
head_size: int,
|
|
||||||
max_seq_len: int,
|
|
||||||
num_experts: int,
|
|
||||||
top_k_experts: int,
|
|
||||||
window_size: int,
|
|
||||||
rope: RoPE,
|
|
||||||
dropout: float = 0.1
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self._heads = GroupedQueryAttention(
|
|
||||||
num_q_heads=num_q_heads,
|
|
||||||
num_kv_heads=num_kv_heads,
|
|
||||||
emb_size=emb_size,
|
|
||||||
head_size=head_size,
|
|
||||||
max_seq_len=max_seq_len,
|
|
||||||
window_size=window_size,
|
|
||||||
rope=rope,
|
|
||||||
dropout=dropout
|
|
||||||
)
|
|
||||||
self._ff = MoE(
|
|
||||||
emb_size=emb_size,
|
|
||||||
num_experts=num_experts,
|
|
||||||
top_k_experts=top_k_experts,
|
|
||||||
dropout=dropout
|
|
||||||
)
|
|
||||||
self._norm1 = RMSNorm(emb_size)
|
|
||||||
self._norm2 = RMSNorm(emb_size)
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
|
||||||
norm1_out = self._norm1(x)
|
|
||||||
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
|
||||||
out = attention + x
|
|
||||||
|
|
||||||
norm2_out = self._norm2(out)
|
|
||||||
ffn_out = self._ff(norm2_out)
|
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
return (ffn_out + out, kv_caches)
|
|
||||||
else:
|
|
||||||
return (ffn_out + out, None)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
from torch import nn
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
|
|
||||||
class Mixtral(BaseModel):
|
class Mixtral(BaseModel):
|
||||||
|
"""
|
||||||
|
Mixtral — языковая модель с архитектурой Mixture-of-Experts на основе современных трансформеров (см. Mixtral 8x7B).
|
||||||
|
|
||||||
|
Описание:
|
||||||
|
---------
|
||||||
|
Данный класс реализует полностью функциональную LLM с блоками MixtralDecoder, которые используют разреженные Feed-Forward сети MoE (Mixture-of-Experts)
|
||||||
|
и Grouped Query Attention (GQA). Позволяет масштабировать количество параметров без экспоненциального роста вычислительных затрат благодаря активации лишь части экспертов на каждый токен.
|
||||||
|
Mixtral поддерживает автотекстогенерацию с caching, position encoding через RoPE и всё необходимое для работы и тренировки современных LLM.
|
||||||
|
|
||||||
|
Архитектурные особенности:
|
||||||
|
--------------------------
|
||||||
|
- Stack из N слоёв MixtralDecoder (каждый — MoE-блок + attention + RMSNorm).
|
||||||
|
- Dropout для регуляризации на уровне эмбеддингов и слоёв.
|
||||||
|
- Позиционные эмбеддинги реализованы через RoPE (Rotary Positional Embeddings).
|
||||||
|
- Финальная RMSNorm плюс Linear-проекция к словарю токенов.
|
||||||
|
- Поддержка автогенерации с sampling (greedy, top-k, top-p), temperature и KV-cache.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
----------------------
|
||||||
|
config : dict
|
||||||
|
Словарь-конфиг с основными гиперпараметрами модели:
|
||||||
|
- vocab_size : int — размер словаря токенов
|
||||||
|
- embed_dim : int — размер скрытого пространства
|
||||||
|
- max_position_embeddings : int — макс. длина последовательности
|
||||||
|
- num_layers : int — количество декодерных блоков в стеке
|
||||||
|
- num_q_heads : int — число query-голов в attention
|
||||||
|
- num_kv_heads : int — число kv-голов в attention
|
||||||
|
- num_experts : int — число MoE-экспертов
|
||||||
|
- top_k_experts : int — сколько экспертов активировать на токен
|
||||||
|
- dropout : float — вероятность Dropout
|
||||||
|
- window_size : int — размер окна внимания
|
||||||
|
|
||||||
|
Основные методы:
|
||||||
|
----------------
|
||||||
|
- forward(x, use_cache=True, cache=None) — прямой проход, поддерживает batched вход, caching.
|
||||||
|
- generate(...) — авторегрессивная генерация с разными стратегиями sampling и ускорением через cache.
|
||||||
|
- save(path)/load(path, device) — сохранение и восстановление обученной модели.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> config = {...} # dict с параметрами
|
||||||
|
>>> model = Mixtral(config)
|
||||||
|
>>> x = torch.randint(0, config["vocab_size"], (2, 16))
|
||||||
|
>>> logits, cache = model(x, use_cache=True)
|
||||||
|
>>> print(logits.shape) # [2, 16, vocab_size]
|
||||||
|
|
||||||
|
>>> # Генерация
|
||||||
|
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.9)
|
||||||
|
|
||||||
|
Литература:
|
||||||
|
-----------
|
||||||
|
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
|
||||||
|
- Switch Transformer: https://arxiv.org/abs/2101.03961
|
||||||
|
- GShard: https://arxiv.org/abs/2006.16668
|
||||||
|
- RoPE: https://arxiv.org/abs/2104.09864
|
||||||
|
- Grouped Query Attention: https://arxiv.org/abs/2305.14236
|
||||||
|
- RMSNorm: https://arxiv.org/abs/1910.07467
|
||||||
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
Конструктор класса Mixtral.
|
||||||
|
|
||||||
|
Осуществляет инициализацию всех модулей и внутренних параметров большой языковой модели с архитектурой Mixtral/MoE.
|
||||||
|
Использует параметры из конфиг-словаря `config` для гибкой настройки модели.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
config : dict
|
||||||
|
Словарь с основными гиперпараметрами архитектуры. Должен содержать ключи:
|
||||||
|
vocab_size (int): Размер словаря токенов.
|
||||||
|
embed_dim (int): Размер скрытого пространства (эмбеддингов).
|
||||||
|
max_position_embeddings (int): Максимальная длина токенной последовательности.
|
||||||
|
num_layers (int): Количество декодерных блоков (слоёв) в модели.
|
||||||
|
num_q_heads (int): Число query-голов (attention heads).
|
||||||
|
num_kv_heads (int): Число key-value голов (attention heads).
|
||||||
|
num_experts (int): Количество экспертов в каждом MoE-блоке.
|
||||||
|
top_k_experts (int): Сколько экспертов активируется для одного токена.
|
||||||
|
dropout (float): Dropout для регуляризации.
|
||||||
|
window_size (int): Размер окна внимания (Attention Window).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Инициализируются эмбеддинги токенов, позиционные эмбеддинги RoPE, Dropout.
|
||||||
|
- Строится стек из num_layers модулей MixtralDecoder с заданным количеством attention heads и экспертов.
|
||||||
|
- Финальный слой нормализации и проекция к логитам словаря (linear layer).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> config = {
|
||||||
|
... "vocab_size": 32000,
|
||||||
|
... "embed_dim": 512,
|
||||||
|
... "max_position_embeddings": 2048,
|
||||||
|
... "num_layers": 24,
|
||||||
|
... "num_q_heads": 8,
|
||||||
|
... "num_kv_heads": 8,
|
||||||
|
... "num_experts": 8,
|
||||||
|
... "top_k_experts": 2,
|
||||||
|
... "dropout": 0.1,
|
||||||
|
... "window_size": 256,
|
||||||
|
... }
|
||||||
|
>>> model = Mixtral(config)
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Конфиг модели должен быть согласован: размеры должны делиться на число голов, число экспертов и top_k_experts корректно выбраны.
|
||||||
|
- Все параметры, необходимые для построения MixtralDecoder, attention и MoE, берутся из config.
|
||||||
|
"""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
self._max_seq_len = config["max_position_embeddings"]
|
self._max_seq_len = config["max_position_embeddings"]
|
||||||
@@ -83,7 +139,7 @@ class Mixtral(BaseModel):
|
|||||||
# emb_size=emb_size
|
# emb_size=emb_size
|
||||||
#)
|
#)
|
||||||
self._dropout = nn.Dropout(config["dropout"])
|
self._dropout = nn.Dropout(config["dropout"])
|
||||||
self._decoders = nn.ModuleList([Decoder(
|
self._decoders = nn.ModuleList([MixtralDecoder(
|
||||||
num_q_heads=config["num_q_heads"],
|
num_q_heads=config["num_q_heads"],
|
||||||
num_kv_heads=config["num_kv_heads"],
|
num_kv_heads=config["num_kv_heads"],
|
||||||
emb_size=config["embed_dim"],
|
emb_size=config["embed_dim"],
|
||||||
@@ -99,6 +155,40 @@ class Mixtral(BaseModel):
|
|||||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через всю модель Mixtral.
|
||||||
|
|
||||||
|
Данный метод реализует трансформацию входной последовательности токенов в логиты (предсказания вероятностей токенов словаря)
|
||||||
|
с поддержкой эффективного инференса с использованием cache (KV-кэш attention для автогенерации).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Двумерный входной тензор shape [batch_size, seq_len], где каждое значение — ID токена.
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Если True — в режиме генерации модель возвращает обновлённый список кэшей attention для ускорения последовательного инференса.
|
||||||
|
Если False — attention cache не используется.
|
||||||
|
cache : list, optional
|
||||||
|
(Необязательно) Список (или None) с кэшем KV attention для каждого слоя. Используется для автогенерации текста.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
tuple:
|
||||||
|
- logits : torch.Tensor — выходной тензор shape [batch_size, seq_len, vocab_size] — массив логитов по токенам и словарю.
|
||||||
|
- new_cache : list или None — обновлённый cache, если используется.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> logits, new_cache = model(x, use_cache=True, cache=None)
|
||||||
|
>>> logits.shape # [batch_size, seq_len, vocab_size]
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
-----------
|
||||||
|
- Если используется cache — эффективно для авторегрессионной генерации (token-by-token), например, при диалогах или длинной генерации.
|
||||||
|
- Если входная последовательность длиннее max_seq_len — будет выброшено исключение.
|
||||||
|
- Если нужен только логит последнего токена — используйте slice: logits[:, -1, :]
|
||||||
|
|
||||||
|
"""
|
||||||
# Проверка длины последовательности (только при отсутствии кэша)
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
if cache is None and x.size(1) > self._max_seq_len:
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||||
@@ -259,33 +349,6 @@ class Mixtral(BaseModel):
|
|||||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
def save(self, path):
|
|
||||||
torch.save({
|
|
||||||
'model_state_dict': self.state_dict(),
|
|
||||||
'vocab_size': self._vocab_size,
|
|
||||||
'max_seq_len': self._max_seq_len,
|
|
||||||
'emb_size': self._emb_size,
|
|
||||||
'num_heads': self._num_heads,
|
|
||||||
'head_size': self._head_size,
|
|
||||||
'num_layers': self._num_layers
|
|
||||||
}, path)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, path, device):
|
|
||||||
checkpoint = torch.load(path, map_location=device)
|
|
||||||
model = cls(
|
|
||||||
vocab_size=checkpoint['vocab_size'],
|
|
||||||
max_seq_len=checkpoint['max_seq_len'],
|
|
||||||
emb_size=checkpoint['emb_size'],
|
|
||||||
num_heads=checkpoint['num_heads'],
|
|
||||||
head_size=checkpoint['head_size'],
|
|
||||||
num_layers=checkpoint['num_layers']
|
|
||||||
)
|
|
||||||
model.load_state_dict(checkpoint['model_state_dict'])
|
|
||||||
model.to(device)
|
|
||||||
return model
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_seq_len(self) -> int:
|
def max_seq_len(self) -> int:
|
||||||
return self._max_seq_len
|
return self._max_seq_len
|
||||||
|
|||||||
80
llm/tests/core/test_mixtral_decoder.py
Normal file
80
llm/tests/core/test_mixtral_decoder.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.mixtral_decoder import MixtralDecoder
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def basic_decoder():
|
||||||
|
emb_size = 16
|
||||||
|
num_q_heads = 4
|
||||||
|
num_kv_heads = 2
|
||||||
|
head_size = 4
|
||||||
|
max_seq_len = 32
|
||||||
|
num_experts = 4
|
||||||
|
top_k_experts = 2
|
||||||
|
window_size = 8
|
||||||
|
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
|
||||||
|
return MixtralDecoder(
|
||||||
|
num_q_heads=num_q_heads,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
num_experts=num_experts,
|
||||||
|
top_k_experts=top_k_experts,
|
||||||
|
window_size=window_size,
|
||||||
|
rope=rope,
|
||||||
|
dropout=0.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_forward_shape(basic_decoder):
|
||||||
|
x = torch.randn(2, 10, 16)
|
||||||
|
out, cache = basic_decoder(x)
|
||||||
|
assert out.shape == (2, 10, 16)
|
||||||
|
assert cache is None or isinstance(cache, (tuple, list))
|
||||||
|
|
||||||
|
def test_forward_masked(basic_decoder):
|
||||||
|
x = torch.randn(3, 7, 16)
|
||||||
|
mask = torch.ones(3, 7, 7, dtype=torch.bool)
|
||||||
|
out, cache = basic_decoder(x, mask=mask)
|
||||||
|
assert out.shape == (3, 7, 16)
|
||||||
|
|
||||||
|
def test_forward_with_cache_flag(basic_decoder):
|
||||||
|
x = torch.randn(2, 8, 16)
|
||||||
|
out, cache = basic_decoder(x, use_cache=True, cache=None)
|
||||||
|
assert out.shape == (2, 8, 16)
|
||||||
|
assert isinstance(cache, (tuple, list)) or cache is None
|
||||||
|
|
||||||
|
def test_backprop_pass(basic_decoder):
|
||||||
|
x = torch.randn(2, 5, 16, requires_grad=True)
|
||||||
|
out, _ = basic_decoder(x)
|
||||||
|
y = out.sum()
|
||||||
|
y.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_seq_too_long_raises(basic_decoder):
|
||||||
|
x = torch.randn(1, 40, 16) # seq_len > max_seq_len
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
basic_decoder(x)
|
||||||
|
|
||||||
|
def test_different_config():
|
||||||
|
rope = RoPE(head_size=2, max_seq_len=12)
|
||||||
|
decoder = MixtralDecoder(
|
||||||
|
num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
|
||||||
|
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
|
||||||
|
)
|
||||||
|
x = torch.randn(1, 8, 4)
|
||||||
|
out, cache = decoder(x)
|
||||||
|
assert out.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_no_dropout():
|
||||||
|
# Проверка на корректность shape при отсутствии Dropout
|
||||||
|
rope = RoPE(head_size=2, max_seq_len=12)
|
||||||
|
decoder = MixtralDecoder(
|
||||||
|
num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
|
||||||
|
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
|
||||||
|
)
|
||||||
|
x = torch.randn(2, 3, 4)
|
||||||
|
out, cache = decoder(x)
|
||||||
|
assert out.shape == x.shape
|
||||||
61
llm/tests/core/test_moe.py
Normal file
61
llm/tests/core/test_moe.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.moe import MoE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def moe():
|
||||||
|
# Базовая MoE для коротких тестов
|
||||||
|
return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)
|
||||||
|
|
||||||
|
def test_forward_shape(moe):
|
||||||
|
x = torch.randn(3, 5, 16) # [batch, seq, emb]
|
||||||
|
y = moe(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_grad(moe):
|
||||||
|
x = torch.randn(2, 4, 16, requires_grad=True)
|
||||||
|
y = moe(x)
|
||||||
|
(y.sum()).backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.shape == x.shape
|
||||||
|
|
||||||
|
def test_top_k_larger_than_experts():
|
||||||
|
# top_k_experts > num_experts должно падать
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
MoE(emb_size=8, num_experts=2, top_k_experts=4)
|
||||||
|
|
||||||
|
def test_single_expert_no_error():
|
||||||
|
# один эксперт, один топ-к — модель всё ещё валидна
|
||||||
|
moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
|
||||||
|
x = torch.randn(2, 2, 8)
|
||||||
|
y = moe(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_trivial_weights():
|
||||||
|
"""Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам."""
|
||||||
|
class DummyMoE(MoE):
|
||||||
|
def forward(self, x):
|
||||||
|
# Роутер отдаёт всегда единичные логиты = softmax -> uniform
|
||||||
|
self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
|
||||||
|
torch.nn.init.constant_(self._router.weight, 0.0)
|
||||||
|
return super().forward(x)
|
||||||
|
moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
|
||||||
|
x = torch.zeros(1, 2, 4)
|
||||||
|
y = moe(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_forward_deterministic_seed(moe):
|
||||||
|
torch.manual_seed(42)
|
||||||
|
x = torch.randn(2, 3, 16)
|
||||||
|
y1 = moe(x)
|
||||||
|
torch.manual_seed(42)
|
||||||
|
y2 = moe(x)
|
||||||
|
assert torch.allclose(y1, y2, atol=1e-5)
|
||||||
|
|
||||||
|
def test_forward_no_dropout():
|
||||||
|
"""Без dropout MoE не меняет shape и не даёт NaN."""
|
||||||
|
moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
|
||||||
|
x = torch.randn(2, 7, 5)
|
||||||
|
y = moe(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert not torch.isnan(y).any()
|
||||||
Reference in New Issue
Block a user