feat(mixtral): add MixtralDecoder, enhance MoE and Mixtral model docs, add unit tests

- Implement new core module: MixtralDecoder (llm/core/mixtral_decoder.py) with full Russian scientific docstrings, formal math, and usage examples
- Improve MoE: add Russian docstrings for class, __init__, forward; validate top_k_experts; explain theory and components
- Refactor Mixtral model: switch stack to MixtralDecoder, add comprehensive documentation for class, constructor and forward, clarify config usage and architecture
- Add thorough unit tests:
   * tests/core/test_mixtral_decoder.py: checks shapes, errors, mask, dropout, grads etc.
   * tests/core/test_moe.py: covers normal and edge-case logic, gradients, shape, params check
- All code and tests in compliance with recent scientific and engineering standards.
This commit is contained in:
Sergey Penkovsky
2025-10-20 16:07:51 +03:00
parent b1737bbce2
commit c9da4c841b
5 changed files with 632 additions and 80 deletions

View File

@@ -0,0 +1,80 @@
import torch
import pytest
from llm.core.mixtral_decoder import MixtralDecoder
from llm.core.rope import RoPE
@pytest.fixture
def basic_decoder():
emb_size = 16
num_q_heads = 4
num_kv_heads = 2
head_size = 4
max_seq_len = 32
num_experts = 4
top_k_experts = 2
window_size = 8
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
return MixtralDecoder(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
num_experts=num_experts,
top_k_experts=top_k_experts,
window_size=window_size,
rope=rope,
dropout=0.0,
)
def test_forward_shape(basic_decoder):
x = torch.randn(2, 10, 16)
out, cache = basic_decoder(x)
assert out.shape == (2, 10, 16)
assert cache is None or isinstance(cache, (tuple, list))
def test_forward_masked(basic_decoder):
x = torch.randn(3, 7, 16)
mask = torch.ones(3, 7, 7, dtype=torch.bool)
out, cache = basic_decoder(x, mask=mask)
assert out.shape == (3, 7, 16)
def test_forward_with_cache_flag(basic_decoder):
x = torch.randn(2, 8, 16)
out, cache = basic_decoder(x, use_cache=True, cache=None)
assert out.shape == (2, 8, 16)
assert isinstance(cache, (tuple, list)) or cache is None
def test_backprop_pass(basic_decoder):
x = torch.randn(2, 5, 16, requires_grad=True)
out, _ = basic_decoder(x)
y = out.sum()
y.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_seq_too_long_raises(basic_decoder):
x = torch.randn(1, 40, 16) # seq_len > max_seq_len
with pytest.raises(Exception):
basic_decoder(x)
def test_different_config():
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
)
x = torch.randn(1, 8, 4)
out, cache = decoder(x)
assert out.shape == x.shape
def test_forward_no_dropout():
# Проверка на корректность shape при отсутствии Dropout
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
)
x = torch.randn(2, 3, 4)
out, cache = decoder(x)
assert out.shape == x.shape

View File

@@ -0,0 +1,61 @@
import torch
import pytest
from llm.core.moe import MoE
@pytest.fixture
def moe():
# Базовая MoE для коротких тестов
return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)
def test_forward_shape(moe):
x = torch.randn(3, 5, 16) # [batch, seq, emb]
y = moe(x)
assert y.shape == x.shape
def test_forward_grad(moe):
x = torch.randn(2, 4, 16, requires_grad=True)
y = moe(x)
(y.sum()).backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_top_k_larger_than_experts():
# top_k_experts > num_experts должно падать
with pytest.raises(ValueError):
MoE(emb_size=8, num_experts=2, top_k_experts=4)
def test_single_expert_no_error():
# один эксперт, один топ-к — модель всё ещё валидна
moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
x = torch.randn(2, 2, 8)
y = moe(x)
assert y.shape == x.shape
def test_forward_trivial_weights():
"""Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам."""
class DummyMoE(MoE):
def forward(self, x):
# Роутер отдаёт всегда единичные логиты = softmax -> uniform
self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
torch.nn.init.constant_(self._router.weight, 0.0)
return super().forward(x)
moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
x = torch.zeros(1, 2, 4)
y = moe(x)
assert y.shape == x.shape
def test_forward_deterministic_seed(moe):
torch.manual_seed(42)
x = torch.randn(2, 3, 16)
y1 = moe(x)
torch.manual_seed(42)
y2 = moe(x)
assert torch.allclose(y1, y2, atol=1e-5)
def test_forward_no_dropout():
"""Без dropout MoE не меняет shape и не даёт NaN."""
moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
x = torch.randn(2, 7, 5)
y = moe(x)
assert y.shape == x.shape
assert not torch.isnan(y).any()