feat(mixtral): add MixtralDecoder, enhance MoE and Mixtral model docs, add unit tests

- Implement new core module: MixtralDecoder (llm/core/mixtral_decoder.py) with full Russian scientific docstrings, formal math, and usage examples
- Improve MoE: add Russian docstrings for class, __init__, forward; validate top_k_experts; explain theory and components
- Refactor Mixtral model: switch stack to MixtralDecoder, add comprehensive documentation for class, constructor and forward, clarify config usage and architecture
- Add thorough unit tests:
   * tests/core/test_mixtral_decoder.py: checks shapes, errors, mask, dropout, grads etc.
   * tests/core/test_moe.py: covers normal and edge-case logic, gradients, shape, params check
- All code and tests in compliance with recent scientific and engineering standards.
This commit is contained in:
Sergey Penkovsky
2025-10-20 16:07:51 +03:00
parent b1737bbce2
commit c9da4c841b
5 changed files with 632 additions and 80 deletions

View File

@@ -0,0 +1,211 @@
from torch import nn
import torch
import torch.nn.functional as F
from llm.core.rope import RoPE
from llm.core.group_query_attention import GroupedQueryAttention
from llm.core.moe import MoE
from llm.core.rms_norm import RMSNorm
class MixtralDecoder(nn.Module):
"""
MixtralDecoder — декодерный блок для Mixtral/MoE-трансформеров (см. Mixtral 8x7B, Mistral v0.2 и др.).
Назначение:
-----------
MixtralDecoder реализует один модульный слой глубокой трансформерной архитектуры с Mixture-of-Experts (MoE) Feed-Forward Network и Grouped Query Attention (GQA).
Поддерживает разреженную активацию и масштабируемое количество экспертов, оптимально для больших LLM.
Архитектура блока:
------------------
- RMSNorm -> Grouped Query Attention (GQA)
- skip-connection
- RMSNorm -> MoE (SwiGLU-эксперты)
- skip-connection
Для входа `x` проходит:
1. norm1_out = RMSNorm(x)
2. attention, kv_caches = GQA(norm1_out, ...)
3. out = attention + x # residual connection
4. norm2_out = RMSNorm(out)
5. ffn_out = MoE(norm2_out)
6. return (ffn_out + out, kv_caches)
Теоретическая мотивация:
------------------------
- Использование MoE (см. https://arxiv.org/abs/1701.06538) позволяет кратно увеличивать capacity без роста затрат на ff-часть.
- Grouped Query Attention эффективно масштабирует self-attention для больших моделей (см. Mistral, Llama 2/3).
- RMSNorm (Root Mean Square LayerNorm) стабилизирует градиенты и память.
- Является строительным блоком для стека декодеров в Mixtral-моделях (см. Mixtral, Mistral, LLaMA).
Аргументы конструктора:
----------------------
num_q_heads : int
Число query-голов в attention.
num_kv_heads : int
Число key-value голов (группировка ключей/values).
emb_size : int
Скрытый размер эмбеддинга.
head_size : int
Размерность одной головы (emb_size // num_q_heads).
max_seq_len : int
Максимальная поддерживаемая длина последовательности.
num_experts : int
Количество «экспертов» (MoE).
top_k_experts : int
Сколько одновременно экспертов активируется для одного токена.
window_size : int
Размер окна внимания (используется для efficient attention).
rope : RoPE
Реализация позиционного кодирования RoPE.
dropout : float
Вероятность Dropout для регуляризации.
Пример использования:
---------------------
>>> decoder = MixtralDecoder(... параметры ...)
>>> x = torch.randn(batch, seq, emb_size)
>>> out, cache = decoder(x, mask=None, use_cache=True)
>>> out.shape
Литература и ссылки:
--------------------
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
- Mistral paper: https://arxiv.org/abs/2310.06825
- GQA: https://arxiv.org/abs/2305.14236
- RMSNorm: https://arxiv.org/abs/1910.07467
"""
def __init__(self,
num_q_heads: int,
num_kv_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
num_experts: int,
top_k_experts: int,
window_size: int,
rope: RoPE,
dropout: float = 0.1
):
"""
Конструктор декодерного блока MixtralDecoder.
Осуществляет инициализацию всех под-компонентов слоя: Attention (Grouped Query Attention), MoE (Mixture-of-Experts, SwiGLU)
и нормализации (RMSNorm). Позволяет гибко настраивать архитектуру под специфику задач и размеры LLM.
Аргументы:
----------
num_q_heads : int
Количество голов внимания (queries) в механизме GroupedQueryAttention.
Чем больше — тем тоньше дискретизация внимания по подпространствам признаков.
num_kv_heads : int
Количество групп ключей/значений (key-value heads) для GQA.
Позволяет балансировать производительность и память.
emb_size : int
Размерность эмбеддингового пространства внутри слоя (hidden).
head_size : int
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
max_seq_len : int
Максимально поддерживаемая длина токенизированной последовательности.
num_experts : int
Количество экспертов в слое MoE (размер пула SwiGLU-экспертов).
top_k_experts : int
Сколько экспертов по роутингу активируется на 1 токен (разреженность — эффективная экономия вычислений).
window_size : int
Размер окна для attention (может использоваться для ограничения receptive field, как в Mistral).
rope : RoPE
Объект позиционного кодирования RoPE (Rotary Positional Embedding), необходим для архитектуры внимания.
dropout : float, по умолчанию 0.1
Вероятность зануляции выходных значений для регуляризации и борьбы с переобучением.
Пример:
-------
>>> decoder = MixtralDecoder(
... num_q_heads=8,
... num_kv_heads=2,
... emb_size=256,
... head_size=32,
... max_seq_len=1024,
... num_experts=4,
... top_k_experts=2,
... window_size=128,
... rope=rope_module,
... dropout=0.05
... )
"""
super().__init__()
self._heads = GroupedQueryAttention(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
window_size=window_size,
rope=rope,
dropout=dropout
)
self._ff = MoE(
emb_size=emb_size,
num_experts=num_experts,
top_k_experts=top_k_experts,
dropout=dropout
)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
"""
Прямой проход (forward) через декодерный блок MixtralDecoder.
Данный метод реализует последовательную обработку входных скрытых состояний (x) через:
- нормализацию (RMSNorm),
- attention-модуль (Grouped Query Attention) с опциональным применением маски и кэша ключей/значений для ускорения инференса,
- остаточное сложение (residual connection),
- повторную нормализацию,
- feed-forward блок на основе Mixture-of-Experts (MoE),
- финальное остаточное сложение.
Аргументы:
----------
x : torch.Tensor
Входной скрытый тензор формы [batch_size, seq_len, emb_size] — результат эмбеддинга токенов либо предыдущего слоя.
mask : torch.Tensor, optional
(Необязательно) Маска внимания для ограничения области self-attention (например, для автоперемешивания или causal-LLM-моделей).
use_cache : bool, по умолчанию True
Если True — сохраняет кэш ключей/значений attention для ускорения авторегрессии (инференса).
cache : list, optional
(Необязательно) Предварительно вычисленный кеш attention (для ускорения генерации длинного текста).
Возвращает:
-----------
Tuple[torch.Tensor, Any]:
- Первый элемент: скрытый тензор выхода слоя с той же формой, что вход (последовательный residual из attention и MoE-блока).
- Второй элемент: обновлённый кэш attention (если use_cache=True), иначе None.
Пример:
-------
>>> out, cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
>>> out.shape # [batch_size, seq_len, emb_size]
Примечания:
-----------
- Для autoregressive-генерации (GPT-like режимов) следует передавать mask и использовать use_cache=True.
- Реализация поддерживает произвольные батчи и длины последовательностей, в пределах max_seq_len слоя.
- Модуль MixtralDecoder обычно используется в виде стека (несколько подряд) внутри крупной LLM.
"""
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)

View File

@@ -4,6 +4,58 @@ import torch.nn.functional as F
from llm.core.swi_glu import SwiGLU from llm.core.swi_glu import SwiGLU
class MoE(nn.Module): class MoE(nn.Module):
"""
MoE (Mixture of Experts) — слой «смеси экспертов» для современных трансформерных архитектур с разреженной активацией.
Назначение:
-----------
Класс реализует слой разреженного условного вычисления для увеличения capacity трансформеров без роста вычислительных затрат.
Для каждого токена из последовательности выбирается (с помощью роутера) наиболее подходящее подмножество экспертов (малых нейросетей).
Итоговый выход формируется как взвешенная сумма откликов экспертов, выбранных для данного токена.
Архитектурная схема:
---------------------
- Для каждого входного токена `x` роутер (обычно один Linear-слой) предсказывает skor, насколько каждый из `num_experts` релевантен.
- Для каждого токена выбираются top_k_experts с максимальными skor; только они обрабатывают этот токен.
- Каждый эксперт здесь представлен отдельным экземпляром блока `SwiGLU` (может быть любая небольшая feed-forward сеть).
- Выход каждого эксперта умножается на индивидуальный вес (softmax по skor) — агрегируется взвешенная сумма.
- Dropout применяется к итоговому выходу.
Математика (коротко):
---------------------
Пусть X ∈ R^{BxSxD} — вход,
E — число экспертов,
K — число активируемых экспертов на токен.
r(x) = softmax(W_r x) — роутинг-логиты, top-K берём индексы и веса.
Для каждого токена:
y_j = Expert_j(x)
y = sum_j(w_j * y_j), где j пробегает по выбранным экспертам
Output: Y ∈ R^{BxSxD}
Аргументы конструктора:
----------------------
emb_size : int
Размерность входных/выходных векторов (обычно совпадает с embedding модели).
num_experts : int
Общее число экспертов внутри слоя MoE.
top_k_experts : int
Сколько экспертов активировать и агрегировать на каждом токене (обычно 2-8).
dropout : float, по умолчанию 0.1
Dropout к выходу агрегатора.
Пример использования:
---------------------
>>> moe = MoE(emb_size=512, num_experts=8, top_k_experts=2, dropout=0.1)
>>> x = torch.randn(4, 16, 512)
>>> y = moe(x)
>>> y.shape # torch.Size([4, 16, 512])
Литература:
-----------
- Shazeer, N. et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
- Fedus, W., Zoph, B., & Shazeer, N. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”, 2021. https://arxiv.org/abs/2101.03961
- Mistral/Mixtral: https://mistral.ai/news/mixtral-of-experts/
"""
def __init__( def __init__(
self, self,
emb_size: int, emb_size: int,
@@ -11,7 +63,51 @@ class MoE(nn.Module):
top_k_experts: int, top_k_experts: int,
dropout: float = 0.1, dropout: float = 0.1,
): ):
"""
Конструктор слоя MoE (Mixture of Experts).
Позволяет создать слой, состоящий из набора экспертов (например, отдельных небольших feedforward-нейросетей) и роутера,
который будет для каждого токена определять наиболее релевантных экспертов.
Часть экспертов (top_k_experts) активируется для каждого токена, остальные — пропускаются.
Аргументы:
----------
emb_size : int
Размерность входных и выходных векторов (embedding size).
Определяет, над каким пространством признаков будет работать роутер и эксперты.
Например, если скрытый размер слоя трансформера 512, сюда нужно передать 512.
num_experts : int
Общее количество экспертов в слое MoE.
Чем больше экспертов — тем больше capacity у модели, но тем выше требования к RAM/VRAM при обучении.
Пример: 8, 16, 32, 64.
top_k_experts : int
Сколько экспертов одновременно будет обрабатывать каждый токен.
Обычно 28. Меньшее значение — выше разреженность, больше экономия вычислений.
dropout : float, по умолчанию 0.1
Вероятность зануления значений на выходе после агрегации откликов экспертов.
Используется для регуляризации (борьбы с переобучением).
Пример:
-------
>>> moe = MoE(emb_size=256, num_experts=8, top_k_experts=2, dropout=0.1)
>>> print(moe)
MoE( ... )
Теория:
-------
Слой строит:
- Линейный роутер (Linear(emb_size, num_experts)): выдает «важность» каждого эксперта для токена.
- Список из num_experts экспертов (в данной реализации — SwiGLU-блоки).
При каждом проходе для каждого токена выбираются top_k_experts наиболее релевантных экспертов,
их ответы агрегируются взвешенной суммой (softmax по роутерным логитам).
"""
super().__init__() super().__init__()
if top_k_experts > num_experts:
raise ValueError(f"top_k_experts ({top_k_experts}) должен быть меньше или равен num_experts ({num_experts})!")
self._num_experts = num_experts self._num_experts = num_experts
self._top_k_experts = top_k_experts self._top_k_experts = top_k_experts
@@ -23,6 +119,47 @@ class MoE(nn.Module):
self._dropout = nn.Dropout(dropout) self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor): def forward(self, x: torch.Tensor):
"""
Прямой проход (forward) через слой MoE.
Для входной последовательности скрытых состояний (обычно из предыдущего слоя трансформера)
данный метод динамически выбирает для каждого токена топ-k наиболее релевантных экспертов с помощью роутера,
пропускает соответствующие токены через выбранных экспертов и агрегирует их результаты.
Математически:
--------------
1. Для каждого токена вычисляются логиты маршрутизатора (роутера):
router_logits = Linear(x) ∈ ^{batch, seq, num_experts}
2. Выбираются top_k экспертов (topk_indices) и соответствующие им softmax-веса (topk_weights).
3. Каждый эксперт обрабатывает только свой поднабор токенов.
4. Результат агрегируется — отклик эксперта умножается на вес, ответы суммируются для каждого токена.
5. На результат применяется dropout для регуляризации.
Аргументы:
----------
x : torch.Tensor
Трёхмерный входной тензор формы [batch_size, seq_length, emb_size],
где batch_size — размер батча, seq_length — длина последовательности, emb_size — размерность эмбеддинга.
Возвращает:
-----------
torch.Tensor :
Тензор той же формы [batch_size, seq_length, emb_size] — результат комбинирования выходов выбранных экспертов
с учетом softmax-весов маршрутизатора и dropout'а.
Пример:
-------
>>> y = moe(x)
>>> print(y.shape)
torch.Size([batch_size, seq_length, emb_size])
Примечание:
-----------
- Каждый токен чаще всего активирует только подмножество экспертов.
- Остальные эксперты вычислительно “спят”, что позволяет строить очень большие (по параметрам) модели с малым ростом затрат.
- Работа с распределением топ-к экспертов и агрегирование с весами реализовано автоматически.
"""
batch_size, seq_len, emb_size = x.shape batch_size, seq_len, emb_size = x.shape
# 1. Пропускаем через роутер # 1. Пропускаем через роутер

View File

@@ -7,64 +7,120 @@ from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings from llm.core.token_embeddings import TokenEmbeddings
from llm.core.rope import RoPE from llm.core.rope import RoPE
from llm.core.rms_norm import RMSNorm from llm.core.rms_norm import RMSNorm
from llm.core.moe import MoE from llm.core.mixtral_decoder import MixtralDecoder
from llm.core.group_query_attention import GroupedQueryAttention
class Decoder(nn.Module):
def __init__(self,
num_q_heads: int,
num_kv_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
num_experts: int,
top_k_experts: int,
window_size: int,
rope: RoPE,
dropout: float = 0.1
):
super().__init__()
self._heads = GroupedQueryAttention(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
window_size=window_size,
rope=rope,
dropout=dropout
)
self._ff = MoE(
emb_size=emb_size,
num_experts=num_experts,
top_k_experts=top_k_experts,
dropout=dropout
)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)
from torch import nn
import torch
import torch.nn.functional as F
class Mixtral(BaseModel): class Mixtral(BaseModel):
"""
Mixtral — языковая модель с архитектурой Mixture-of-Experts на основе современных трансформеров (см. Mixtral 8x7B).
Описание:
---------
Данный класс реализует полностью функциональную LLM с блоками MixtralDecoder, которые используют разреженные Feed-Forward сети MoE (Mixture-of-Experts)
и Grouped Query Attention (GQA). Позволяет масштабировать количество параметров без экспоненциального роста вычислительных затрат благодаря активации лишь части экспертов на каждый токен.
Mixtral поддерживает автотекстогенерацию с caching, position encoding через RoPE и всё необходимое для работы и тренировки современных LLM.
Архитектурные особенности:
--------------------------
- Stack из N слоёв MixtralDecoder (каждый — MoE-блок + attention + RMSNorm).
- Dropout для регуляризации на уровне эмбеддингов и слоёв.
- Позиционные эмбеддинги реализованы через RoPE (Rotary Positional Embeddings).
- Финальная RMSNorm плюс Linear-проекция к словарю токенов.
- Поддержка автогенерации с sampling (greedy, top-k, top-p), temperature и KV-cache.
Аргументы конструктора:
----------------------
config : dict
Словарь-конфиг с основными гиперпараметрами модели:
- vocab_size : int — размер словаря токенов
- embed_dim : int — размер скрытого пространства
- max_position_embeddings : int — макс. длина последовательности
- num_layers : int — количество декодерных блоков в стеке
- num_q_heads : int — число query-голов в attention
- num_kv_heads : int — число kv-голов в attention
- num_experts : int — число MoE-экспертов
- top_k_experts : int — сколько экспертов активировать на токен
- dropout : float — вероятность Dropout
- window_size : int — размер окна внимания
Основные методы:
----------------
- forward(x, use_cache=True, cache=None) — прямой проход, поддерживает batched вход, caching.
- generate(...) — авторегрессивная генерация с разными стратегиями sampling и ускорением через cache.
- save(path)/load(path, device) — сохранение и восстановление обученной модели.
Пример:
-------
>>> config = {...} # dict с параметрами
>>> model = Mixtral(config)
>>> x = torch.randint(0, config["vocab_size"], (2, 16))
>>> logits, cache = model(x, use_cache=True)
>>> print(logits.shape) # [2, 16, vocab_size]
>>> # Генерация
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.9)
Литература:
-----------
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
- Switch Transformer: https://arxiv.org/abs/2101.03961
- GShard: https://arxiv.org/abs/2006.16668
- RoPE: https://arxiv.org/abs/2104.09864
- Grouped Query Attention: https://arxiv.org/abs/2305.14236
- RMSNorm: https://arxiv.org/abs/1910.07467
"""
def __init__(self, config): def __init__(self, config):
"""
Конструктор класса Mixtral.
Осуществляет инициализацию всех модулей и внутренних параметров большой языковой модели с архитектурой Mixtral/MoE.
Использует параметры из конфиг-словаря `config` для гибкой настройки модели.
Аргументы:
----------
config : dict
Словарь с основными гиперпараметрами архитектуры. Должен содержать ключи:
vocab_size (int): Размер словаря токенов.
embed_dim (int): Размер скрытого пространства (эмбеддингов).
max_position_embeddings (int): Максимальная длина токенной последовательности.
num_layers (int): Количество декодерных блоков (слоёв) в модели.
num_q_heads (int): Число query-голов (attention heads).
num_kv_heads (int): Число key-value голов (attention heads).
num_experts (int): Количество экспертов в каждом MoE-блоке.
top_k_experts (int): Сколько экспертов активируется для одного токена.
dropout (float): Dropout для регуляризации.
window_size (int): Размер окна внимания (Attention Window).
Внутри:
-------
- Инициализируются эмбеддинги токенов, позиционные эмбеддинги RoPE, Dropout.
- Строится стек из num_layers модулей MixtralDecoder с заданным количеством attention heads и экспертов.
- Финальный слой нормализации и проекция к логитам словаря (linear layer).
Пример:
-------
>>> config = {
... "vocab_size": 32000,
... "embed_dim": 512,
... "max_position_embeddings": 2048,
... "num_layers": 24,
... "num_q_heads": 8,
... "num_kv_heads": 8,
... "num_experts": 8,
... "top_k_experts": 2,
... "dropout": 0.1,
... "window_size": 256,
... }
>>> model = Mixtral(config)
Примечания:
-----------
- Конфиг модели должен быть согласован: размеры должны делиться на число голов, число экспертов и top_k_experts корректно выбраны.
- Все параметры, необходимые для построения MixtralDecoder, attention и MoE, берутся из config.
"""
super().__init__(config) super().__init__(config)
self._max_seq_len = config["max_position_embeddings"] self._max_seq_len = config["max_position_embeddings"]
@@ -83,7 +139,7 @@ class Mixtral(BaseModel):
# emb_size=emb_size # emb_size=emb_size
#) #)
self._dropout = nn.Dropout(config["dropout"]) self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([Decoder( self._decoders = nn.ModuleList([MixtralDecoder(
num_q_heads=config["num_q_heads"], num_q_heads=config["num_q_heads"],
num_kv_heads=config["num_kv_heads"], num_kv_heads=config["num_kv_heads"],
emb_size=config["embed_dim"], emb_size=config["embed_dim"],
@@ -99,6 +155,40 @@ class Mixtral(BaseModel):
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"]) self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple: def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
"""
Прямой проход (forward) через всю модель Mixtral.
Данный метод реализует трансформацию входной последовательности токенов в логиты (предсказания вероятностей токенов словаря)
с поддержкой эффективного инференса с использованием cache (KV-кэш attention для автогенерации).
Аргументы:
----------
x : torch.Tensor
Двумерный входной тензор shape [batch_size, seq_len], где каждое значение — ID токена.
use_cache : bool, по умолчанию True
Если True — в режиме генерации модель возвращает обновлённый список кэшей attention для ускорения последовательного инференса.
Если False — attention cache не используется.
cache : list, optional
(Необязательно) Список (или None) с кэшем KV attention для каждого слоя. Используется для автогенерации текста.
Возвращает:
-----------
tuple:
- logits : torch.Tensor — выходной тензор shape [batch_size, seq_len, vocab_size] — массив логитов по токенам и словарю.
- new_cache : list или None — обновлённый cache, если используется.
Пример:
-------
>>> logits, new_cache = model(x, use_cache=True, cache=None)
>>> logits.shape # [batch_size, seq_len, vocab_size]
Примечания:
-----------
- Если используется cache — эффективно для авторегрессионной генерации (token-by-token), например, при диалогах или длинной генерации.
- Если входная последовательность длиннее max_seq_len — будет выброшено исключение.
- Если нужен только логит последнего токена — используйте slice: logits[:, -1, :]
"""
# Проверка длины последовательности (только при отсутствии кэша) # Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len: if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}") raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
@@ -259,33 +349,6 @@ class Mixtral(BaseModel):
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1] x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x return x
def save(self, path):
torch.save({
'model_state_dict': self.state_dict(),
'vocab_size': self._vocab_size,
'max_seq_len': self._max_seq_len,
'emb_size': self._emb_size,
'num_heads': self._num_heads,
'head_size': self._head_size,
'num_layers': self._num_layers
}, path)
@classmethod
def load(cls, path, device):
checkpoint = torch.load(path, map_location=device)
model = cls(
vocab_size=checkpoint['vocab_size'],
max_seq_len=checkpoint['max_seq_len'],
emb_size=checkpoint['emb_size'],
num_heads=checkpoint['num_heads'],
head_size=checkpoint['head_size'],
num_layers=checkpoint['num_layers']
)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
return model
@property @property
def max_seq_len(self) -> int: def max_seq_len(self) -> int:
return self._max_seq_len return self._max_seq_len

View File

@@ -0,0 +1,80 @@
import torch
import pytest
from llm.core.mixtral_decoder import MixtralDecoder
from llm.core.rope import RoPE
@pytest.fixture
def basic_decoder():
emb_size = 16
num_q_heads = 4
num_kv_heads = 2
head_size = 4
max_seq_len = 32
num_experts = 4
top_k_experts = 2
window_size = 8
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
return MixtralDecoder(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
num_experts=num_experts,
top_k_experts=top_k_experts,
window_size=window_size,
rope=rope,
dropout=0.0,
)
def test_forward_shape(basic_decoder):
x = torch.randn(2, 10, 16)
out, cache = basic_decoder(x)
assert out.shape == (2, 10, 16)
assert cache is None or isinstance(cache, (tuple, list))
def test_forward_masked(basic_decoder):
x = torch.randn(3, 7, 16)
mask = torch.ones(3, 7, 7, dtype=torch.bool)
out, cache = basic_decoder(x, mask=mask)
assert out.shape == (3, 7, 16)
def test_forward_with_cache_flag(basic_decoder):
x = torch.randn(2, 8, 16)
out, cache = basic_decoder(x, use_cache=True, cache=None)
assert out.shape == (2, 8, 16)
assert isinstance(cache, (tuple, list)) or cache is None
def test_backprop_pass(basic_decoder):
x = torch.randn(2, 5, 16, requires_grad=True)
out, _ = basic_decoder(x)
y = out.sum()
y.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_seq_too_long_raises(basic_decoder):
x = torch.randn(1, 40, 16) # seq_len > max_seq_len
with pytest.raises(Exception):
basic_decoder(x)
def test_different_config():
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
)
x = torch.randn(1, 8, 4)
out, cache = decoder(x)
assert out.shape == x.shape
def test_forward_no_dropout():
# Проверка на корректность shape при отсутствии Dropout
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
)
x = torch.randn(2, 3, 4)
out, cache = decoder(x)
assert out.shape == x.shape

View File

@@ -0,0 +1,61 @@
import torch
import pytest
from llm.core.moe import MoE
@pytest.fixture
def moe():
# Базовая MoE для коротких тестов
return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)
def test_forward_shape(moe):
x = torch.randn(3, 5, 16) # [batch, seq, emb]
y = moe(x)
assert y.shape == x.shape
def test_forward_grad(moe):
x = torch.randn(2, 4, 16, requires_grad=True)
y = moe(x)
(y.sum()).backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_top_k_larger_than_experts():
# top_k_experts > num_experts должно падать
with pytest.raises(ValueError):
MoE(emb_size=8, num_experts=2, top_k_experts=4)
def test_single_expert_no_error():
# один эксперт, один топ-к — модель всё ещё валидна
moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
x = torch.randn(2, 2, 8)
y = moe(x)
assert y.shape == x.shape
def test_forward_trivial_weights():
"""Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам."""
class DummyMoE(MoE):
def forward(self, x):
# Роутер отдаёт всегда единичные логиты = softmax -> uniform
self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
torch.nn.init.constant_(self._router.weight, 0.0)
return super().forward(x)
moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
x = torch.zeros(1, 2, 4)
y = moe(x)
assert y.shape == x.shape
def test_forward_deterministic_seed(moe):
torch.manual_seed(42)
x = torch.randn(2, 3, 16)
y1 = moe(x)
torch.manual_seed(42)
y2 = moe(x)
assert torch.allclose(y1, y2, atol=1e-5)
def test_forward_no_dropout():
"""Без dropout MoE не меняет shape и не даёт NaN."""
moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
x = torch.randn(2, 7, 5)
y = moe(x)
assert y.shape == x.shape
assert not torch.isnan(y).any()