Mirror of https://github.com/pese-git/llm-arch-research.git, synced 2026-01-24 05:21:16 +00:00
feat(mixtral): add MixtralDecoder, enhance MoE and Mixtral model docs, add unit tests
- Implement a new core module, MixtralDecoder (llm/core/mixtral_decoder.py), with full Russian scientific docstrings, formal math, and usage examples
- Improve MoE: add Russian docstrings for the class, __init__, and forward; validate top_k_experts (see the sketch after this list); explain the theory and components
- Refactor the Mixtral model: switch the stack to MixtralDecoder, add comprehensive documentation for the class, constructor, and forward, and clarify config usage and architecture
- Add thorough unit tests:
  * tests/core/test_mixtral_decoder.py: checks shapes, errors, masking, dropout, gradients, etc.
  * tests/core/test_moe.py: covers normal and edge-case logic, gradients, shapes, and parameter checks
- All code and tests comply with recent scientific and engineering standards.
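The "validate top_k_experts" item above refers to a check in llm/core/moe.py that the tests below exercise via pytest.raises(ValueError). A minimal sketch of what such a guard might look like, written here as a standalone helper named validate_top_k (hypothetical; the real check sits inside MoE.__init__ and may differ in detail):

def validate_top_k(num_experts: int, top_k_experts: int) -> None:
    # Hypothetical helper illustrating the guard described in the commit message;
    # the actual validation lives in MoE.__init__ in llm/core/moe.py.
    if not 1 <= top_k_experts <= num_experts:
        raise ValueError(
            f"top_k_experts={top_k_experts} must be between 1 and num_experts={num_experts}"
        )

test_top_k_larger_than_experts in tests/core/test_moe.py relies on exactly this kind of ValueError being raised when top_k_experts > num_experts.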
llm/tests/core/test_mixtral_decoder.py (new file, 80 lines)
@@ -0,0 +1,80 @@
import torch
import pytest
from llm.core.mixtral_decoder import MixtralDecoder
from llm.core.rope import RoPE


@pytest.fixture
def basic_decoder():
    emb_size = 16
    num_q_heads = 4
    num_kv_heads = 2
    head_size = 4
    max_seq_len = 32
    num_experts = 4
    top_k_experts = 2
    window_size = 8
    rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
    return MixtralDecoder(
        num_q_heads=num_q_heads,
        num_kv_heads=num_kv_heads,
        emb_size=emb_size,
        head_size=head_size,
        max_seq_len=max_seq_len,
        num_experts=num_experts,
        top_k_experts=top_k_experts,
        window_size=window_size,
        rope=rope,
        dropout=0.0,
    )


def test_forward_shape(basic_decoder):
    x = torch.randn(2, 10, 16)
    out, cache = basic_decoder(x)
    assert out.shape == (2, 10, 16)
    assert cache is None or isinstance(cache, (tuple, list))


def test_forward_masked(basic_decoder):
    x = torch.randn(3, 7, 16)
    mask = torch.ones(3, 7, 7, dtype=torch.bool)
    out, cache = basic_decoder(x, mask=mask)
    assert out.shape == (3, 7, 16)


def test_forward_with_cache_flag(basic_decoder):
    x = torch.randn(2, 8, 16)
    out, cache = basic_decoder(x, use_cache=True, cache=None)
    assert out.shape == (2, 8, 16)
    assert isinstance(cache, (tuple, list)) or cache is None


def test_backprop_pass(basic_decoder):
    x = torch.randn(2, 5, 16, requires_grad=True)
    out, _ = basic_decoder(x)
    y = out.sum()
    y.backward()
    assert x.grad is not None
    assert x.grad.shape == x.shape


def test_seq_too_long_raises(basic_decoder):
    x = torch.randn(1, 40, 16)  # seq_len > max_seq_len
    with pytest.raises(Exception):
        basic_decoder(x)


def test_different_config():
    rope = RoPE(head_size=2, max_seq_len=12)
    decoder = MixtralDecoder(
        num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
        max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
    )
    x = torch.randn(1, 8, 4)
    out, cache = decoder(x)
    assert out.shape == x.shape


def test_forward_no_dropout():
    # Check that the output shape is correct when dropout is disabled
    rope = RoPE(head_size=2, max_seq_len=12)
    decoder = MixtralDecoder(
        num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
        max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
    )
    x = torch.randn(2, 3, 4)
    out, cache = decoder(x)
    assert out.shape == x.shape
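test_forward_with_cache_flag above only checks the return signature (out, cache) when use_cache=True. Below is a minimal sketch of how incremental decoding might look on top of that signature, under the assumption (not confirmed by these tests) that the returned cache can be fed back into the next call:

import torch
from llm.core.mixtral_decoder import MixtralDecoder
from llm.core.rope import RoPE

# Assumed usage only: cache round-tripping is inferred from the (out, cache)
# return value exercised in the tests and may not match the real API.
rope = RoPE(head_size=4, max_seq_len=32)
decoder = MixtralDecoder(
    num_q_heads=4, num_kv_heads=2, emb_size=16, head_size=4,
    max_seq_len=32, num_experts=4, top_k_experts=2, window_size=8,
    rope=rope, dropout=0.0,
)
cache = None
for _ in range(4):
    step = torch.randn(1, 1, 16)  # one new token embedding per step
    out, cache = decoder(step, use_cache=True, cache=cache)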
llm/tests/core/test_moe.py (new file, 61 lines)
@@ -0,0 +1,61 @@
import torch
import pytest
from llm.core.moe import MoE


@pytest.fixture
def moe():
    # Basic MoE instance for short tests
    return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)


def test_forward_shape(moe):
    x = torch.randn(3, 5, 16)  # [batch, seq, emb]
    y = moe(x)
    assert y.shape == x.shape


def test_forward_grad(moe):
    x = torch.randn(2, 4, 16, requires_grad=True)
    y = moe(x)
    (y.sum()).backward()
    assert x.grad is not None
    assert x.grad.shape == x.shape


def test_top_k_larger_than_experts():
    # top_k_experts > num_experts must raise
    with pytest.raises(ValueError):
        MoE(emb_size=8, num_experts=2, top_k_experts=4)


def test_single_expert_no_error():
    # A single expert with top-k of one is still a valid model
    moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
    x = torch.randn(2, 2, 8)
    y = moe(x)
    assert y.shape == x.shape


def test_forward_trivial_weights():
    """Checks that with identical router weights MoE returns the average over the experts."""
    class DummyMoE(MoE):
        def forward(self, x):
            # The router always emits identical logits, so softmax -> uniform weights
            self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
            torch.nn.init.constant_(self._router.weight, 0.0)
            return super().forward(x)

    moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
    x = torch.zeros(1, 2, 4)
    y = moe(x)
    assert y.shape == x.shape


def test_forward_deterministic_seed(moe):
    torch.manual_seed(42)
    x = torch.randn(2, 3, 16)
    y1 = moe(x)
    torch.manual_seed(42)
    y2 = moe(x)
    assert torch.allclose(y1, y2, atol=1e-5)


def test_forward_no_dropout():
    """Without dropout MoE preserves the shape and produces no NaNs."""
    moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
    x = torch.randn(2, 7, 5)
    y = moe(x)
    assert y.shape == x.shape
    assert not torch.isnan(y).any()
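As a standalone illustration of the idea behind test_forward_trivial_weights: when the router emits identical logits for every expert, softmax produces uniform gating weights, so a top-k that covers all experts reduces to a plain average of the expert outputs. The sketch below uses only plain PyTorch and does not depend on the repository's MoE class:

import torch

emb_size, num_experts = 4, 2
x = torch.zeros(1, 2, emb_size)  # [batch, seq, emb]
experts = [torch.nn.Linear(emb_size, emb_size) for _ in range(num_experts)]

logits = torch.zeros(*x.shape[:-1], num_experts)  # identical logits per token
weights = torch.softmax(logits, dim=-1)           # uniform: each weight is 1/num_experts

expert_outs = torch.stack([e(x) for e in experts], dim=-1)  # [batch, seq, emb, num_experts]
mixed = (expert_outs * weights.unsqueeze(-2)).sum(dim=-1)   # router-weighted combination
plain_mean = expert_outs.mean(dim=-1)                       # unweighted average over experts

assert torch.allclose(mixed, plain_mean, atol=1e-6)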