feat(gemma): document and test GeGLU, MultiQueryAttention, GemmaDecoder, update Gemma model docs

- Add new core modules: GeGLU (Gated GELU Linear Unit), GemmaDecoder, MultiQueryAttention; all with highly detailed scientific (RU) docstrings: theory, usage, formulas, references
- Major doc improvements in Gemma model: class, __init__, forward, generate now have full educational/engineering docstrings, use-case samples, and literature links
- Add comprehensive unit tests:
    * tests/core/test_geglu.py: GeGLU coverage (shape, grads, edge, repeat, float16/skip)
    * tests/core/test_gemma_decoder.py: GemmaDecoder coverage (shape, mask, cache, repeatability, errors)
    * tests/core/test_multi_query_attention.py: MQA coverage (shape, cache, gradients, masking, dropout, raise)
- All modules and tests follow strict quality/documentation standards, code is now robust for research & production
This commit is contained in:
Sergey Penkovsky
2025-10-21 15:12:45 +03:00
parent cfb4b6dfb1
commit ea932a36f3
7 changed files with 962 additions and 293 deletions

View File

@@ -0,0 +1,60 @@
import torch
import pytest
from llm.core.geglu import GeGLU
@pytest.fixture
def geglu():
return GeGLU(emb_size=16, dropout=0.1)
def test_forward_shape(geglu):
x = torch.randn(2, 5, 16)
y = geglu(x)
assert y.shape == x.shape
def test_forward_no_batch(geglu):
x = torch.randn(1, 16)
y = geglu(x.unsqueeze(0))
assert y.shape == (1, 1, 16)
@pytest.mark.skip(reason="float16 not supported without parameter casting")
def test_forward_dtype_fp16():
geglu = GeGLU(emb_size=8, dropout=0.0)
x = torch.randn(2, 4, 8).half()
y = geglu(x)
assert y.shape == x.shape
assert y.dtype == torch.float16
def test_forward_no_dropout():
geglu = GeGLU(emb_size=4, dropout=0.0)
x = torch.randn(3, 2, 4)
y = geglu(x)
assert not torch.isnan(y).any()
assert not torch.isinf(y).any()
def test_gradient_flow(geglu):
x = torch.randn(3, 8, 16, requires_grad=True)
y = geglu(x)
y.sum().backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_forward_repeatability():
torch.manual_seed(42)
geglu = GeGLU(emb_size=8, dropout=0.0)
x = torch.randn(3, 2, 8)
y1 = geglu(x)
torch.manual_seed(42)
geglu2 = GeGLU(emb_size=8, dropout=0.0)
x2 = torch.randn(3, 2, 8)
y2 = geglu2(x2)
assert torch.allclose(y1, y2, atol=1e-5)
def test_edge_small_large():
geglu = GeGLU(emb_size=2, dropout=0.0)
x = torch.randn(2, 2, 2)
y = geglu(x)
assert y.shape == x.shape
geglu = GeGLU(emb_size=256, dropout=0.0)
x = torch.randn(1, 1, 256)
y = geglu(x)
assert y.shape == x.shape

View File

@@ -0,0 +1,67 @@
import torch
import pytest
from llm.core.gemma_decoder import GemmaDecoder
from llm.core.rope import RoPE
@pytest.fixture
def gemma_decoder():
rope = RoPE(head_size=4, max_seq_len=32)
return GemmaDecoder(
num_q_heads=4,
emb_size=16,
head_size=4,
max_seq_len=32,
rope=rope,
dropout=0.1,
)
def test_forward_shape(gemma_decoder):
x = torch.randn(2, 12, 16)
out, cache = gemma_decoder(x)
assert out.shape == (2, 12, 16)
assert isinstance(cache, tuple) or cache is None
def test_forward_masked(gemma_decoder):
x = torch.randn(1, 8, 16)
mask = torch.ones(1, 8, 8, dtype=torch.bool)
out, _ = gemma_decoder(x, mask=mask)
assert out.shape == x.shape
def test_forward_with_cache_flag(gemma_decoder):
x = torch.randn(2, 7, 16)
out, cache = gemma_decoder(x, use_cache=True, cache=None)
assert out.shape == (2, 7, 16)
def test_forward_wrong_seq_len_raises(gemma_decoder):
x = torch.randn(1, 100, 16)
with pytest.raises(Exception):
gemma_decoder(x)
def test_gradient_flow(gemma_decoder):
x = torch.randn(3, 9, 16, requires_grad=True)
y, _ = gemma_decoder(x)
y.sum().backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_various_shapes(gemma_decoder):
for b, s in [(1, 1), (2, 5), (2, 32)]:
x = torch.randn(b, s, 16)
y, _ = gemma_decoder(x)
assert y.shape == (b, s, 16)
def test_forward_repeatability():
torch.manual_seed(42)
rope = RoPE(head_size=4, max_seq_len=32)
decoder = GemmaDecoder(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
)
x = torch.randn(2, 8, 16)
y1, _ = decoder(x)
torch.manual_seed(42)
decoder2 = GemmaDecoder(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
)
x2 = torch.randn(2, 8, 16)
y2, _ = decoder2(x2)
assert torch.allclose(y1, y2, atol=1e-5)

View File

@@ -0,0 +1,71 @@
import torch
import pytest
from llm.core.multi_query_attention import MultiQueryAttention
from llm.core.rope import RoPE
@pytest.fixture
def mqa_rope():
return MultiQueryAttention(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=RoPE(head_size=4, max_seq_len=32), dropout=0.1
)
@pytest.fixture
def mqa_no_rope():
return MultiQueryAttention(
num_q_heads=2, emb_size=8, head_size=4, max_seq_len=16, rope=None, dropout=0.0
)
def test_forward_shape(mqa_rope):
x = torch.randn(2, 10, 16)
out, cache = mqa_rope(x)
assert out.shape == (2, 10, 16)
assert isinstance(cache, tuple) and len(cache) == 2
def test_forward_masked(mqa_rope):
x = torch.randn(2, 8, 16)
mask = torch.ones(2, 8, 8, dtype=torch.bool)
out, cache = mqa_rope(x, mask=mask)
assert out.shape == (2, 8, 16)
def test_forward_cache(mqa_rope):
x = torch.randn(1, 4, 16)
# Первый вызов — кэша нет
out1, cache1 = mqa_rope(x)
# Повторяем: подаем x второй раз — теперь добавим cache
out2, cache2 = mqa_rope(x, use_cache=True, cache=cache1)
assert out2.shape == (1, 4, 16)
assert isinstance(cache2, tuple) and len(cache2) == 2
# Проверка, что длина k_cache увеличилась
assert cache2[0].shape[2] == cache1[0].shape[2] + x.shape[1] # по длине seq
def test_forward_no_rope(mqa_no_rope):
x = torch.randn(3, 6, 8)
out, _ = mqa_no_rope(x)
assert out.shape == (3, 6, 8)
def test_forward_different_batch_seq(mqa_rope):
for batch, seq in [(1, 1), (2, 5), (3, 32)]:
x = torch.randn(batch, seq, 16)
out, _ = mqa_rope(x)
assert out.shape == (batch, seq, 16)
def test_forward_raise_on_long_seq(mqa_rope):
x = torch.randn(2, 40, 16) # seq_len > max_seq_len
with pytest.raises(ValueError):
mqa_rope(x)
def test_forward_grad(mqa_rope):
x = torch.randn(2, 7, 16, requires_grad=True)
out, _ = mqa_rope(x)
y = out.sum()
y.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_dropout_applied():
mqa = MultiQueryAttention(num_q_heads=2, emb_size=8, head_size=4, max_seq_len=12, rope=None, dropout=0.99)
x = torch.ones(1, 3, 8)
mqa.train()
y, _ = mqa(x)
# При очень большом dropout почти всё обнуляется
assert (torch.abs(y) < 1e-5).float().mean() > 0.6 or y.sum() < 1e-2