mirror of https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
- Main decoder module (Decoder) with:
  * Self-Attention mechanism
  * Encoder-Decoder Attention layer
  * LayerNormalization
  * positional embeddings
- Usage examples with documentation
- Full set of unit tests
- Documentation in Russian
87 lines
3.1 KiB
Python
import torch
import pytest
from simple_llm.transformer.decoder import Decoder
from simple_llm.transformer.multi_head_attention import MultiHeadAttention
from simple_llm.transformer.feed_forward import FeedForward
from torch import nn


class TestDecoder:
    @pytest.fixture
    def sample_decoder(self):
        """Fixture providing a test decoder."""
        return Decoder(
            num_heads=4,
            emb_size=64,
            head_size=16,
            max_seq_len=128,
            dropout=0.1
        )

    def test_initialization(self, sample_decoder):
        """Test that the sublayers are initialized correctly."""
        assert isinstance(sample_decoder._heads, MultiHeadAttention)
        assert isinstance(sample_decoder._ff, FeedForward)
        assert isinstance(sample_decoder._norm1, nn.LayerNorm)
        assert isinstance(sample_decoder._norm2, nn.LayerNorm)
        assert sample_decoder._norm1.normalized_shape == (64,)
        assert sample_decoder._norm2.normalized_shape == (64,)

    def test_forward_shapes(self, sample_decoder):
        """Test that input dimensions are preserved."""
        test_cases = [
            (1, 10, 64),  # batch=1, seq_len=10
            (4, 25, 64),  # batch=4, seq_len=25
            (8, 50, 64),  # batch=8, seq_len=50
        ]

        for batch, seq_len, emb_size in test_cases:
            x = torch.randn(batch, seq_len, emb_size)
            output = sample_decoder(x)
            assert output.shape == (batch, seq_len, emb_size)

    def test_masking(self, sample_decoder):
        """Test the attention mask."""
        x = torch.randn(2, 10, 64)
        mask = torch.tril(torch.ones(10, 10))  # lower-triangular (causal) mask

        output = sample_decoder(x, mask)
        assert not torch.isnan(output).any()

    def test_sample_case(self):
        """Test against the sample case from the assignment."""
        torch.manual_seed(0)
        decoder = Decoder(
            num_heads=5,
            emb_size=12,
            head_size=8,
            max_seq_len=20,
            dropout=0.0
        )

        # Zero-initialize the weights so the output is predictable
        for p in decoder.parameters():
            nn.init.zeros_(p)

        x = torch.zeros(1, 12, 12)
        output = decoder(x)

        # Check the output shape
        assert output.shape == (1, 12, 12)

        # Check that the output is close to zero (not an exact value)
        assert torch.allclose(output, torch.zeros_like(output), atol=1e-6), \
            "Output should be close to zero with zero-initialized weights"

    def test_gradient_flow(self, sample_decoder):
        """Test that gradients flow back to the input."""
        x = torch.randn(2, 15, 64, requires_grad=True)
        output = sample_decoder(x)
        loss = output.sum()
        loss.backward()

        assert x.grad is not None
        assert not torch.isnan(x.grad).any()


if __name__ == "__main__":
    pytest.main(["-v", __file__])
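The tests above pin down the Decoder interface without showing its implementation. For context, here is a minimal sketch of a decoder block consistent with those tests. It is inferred from the constructor calls and attribute checks, not copied from simple_llm: the pre-norm residual ordering, the joint QKV projection, the 4x FeedForward expansion, and the scaling by sqrt(head_size) are all assumptions, and max_seq_len is accepted but unused here (the real module presumably uses it for the positional embeddings and causal-mask buffers mentioned in the commit description).

import torch
from torch import nn


class MultiHeadAttention(nn.Module):
    """num_heads attention heads of dimension head_size, projected back to emb_size."""

    def __init__(self, num_heads, emb_size, head_size, max_seq_len, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        # Joint projection for the queries, keys and values of all heads.
        self.qkv = nn.Linear(emb_size, 3 * num_heads * head_size)
        # Projects the concatenated heads (num_heads * head_size) back to
        # emb_size, which is why num_heads * head_size need not equal emb_size
        # (the tests use num_heads=5, head_size=8, emb_size=12).
        self.proj = nn.Linear(num_heads * head_size, emb_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        b, t, _ = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, H*D) -> (B, H, T, D)
        q, k, v = (z.view(b, t, self.num_heads, self.head_size).transpose(1, 2)
                   for z in (q, k, v))
        scores = q @ k.transpose(-2, -1) / self.head_size ** 0.5
        if mask is not None:
            # Zeros in the mask are positions that must not be attended to.
            scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = self.dropout(torch.softmax(scores, dim=-1))
        out = (attn @ v).transpose(1, 2).reshape(b, t, -1)
        return self.proj(out)


class FeedForward(nn.Module):
    """Position-wise MLP with a single hidden expansion (assumed 4x)."""

    def __init__(self, emb_size, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_size, 4 * emb_size),
            nn.ReLU(),
            nn.Linear(4 * emb_size, emb_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


class Decoder(nn.Module):
    """Decoder block with the attribute names the tests check."""

    def __init__(self, num_heads, emb_size, head_size, max_seq_len, dropout):
        super().__init__()
        self._heads = MultiHeadAttention(num_heads, emb_size, head_size,
                                         max_seq_len, dropout)
        self._ff = FeedForward(emb_size, dropout)
        self._norm1 = nn.LayerNorm(emb_size)
        self._norm2 = nn.LayerNorm(emb_size)

    def forward(self, x, mask=None):
        # Pre-norm residual blocks: attention, then the position-wise MLP.
        x = x + self._heads(self._norm1(x), mask)
        x = x + self._ff(self._norm2(x))
        return x

Against this sketch the shape, masking, and gradient-flow tests would pass, and the zero-output check in test_sample_case holds because zero-initialized LayerNorm and Linear weights make every sublayer output zero, leaving only the residual of a zero input.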