Files
simple-llm/tests/test_decoder.py
Sergey Penkovsky 420c45dc74 Transformer Decoder implementation
- Core decoder module (Decoder), sketched below, with:
  * a self-attention mechanism
  * an encoder-decoder attention layer
  * layer normalization
  * positional embeddings
- Usage examples with documentation
- A complete suite of unit tests
- Documentation in Russian
2025-07-21 11:00:49 +03:00
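
The decoder implementation itself is not shown on this page. Purely for orientation, here is a minimal sketch of the interface the tests below appear to assume. Only the constructor signature and the attribute names _heads, _ff, _norm1, and _norm2 are taken from the tests; everything else (the torch-only stand-in attention, the 4x feed-forward width, the post-norm residuals, and the unused max_seq_len) is an assumption, not the project's actual code, which lives in simple_llm/transformer/decoder.py.

import math

from torch import nn


class MultiHeadAttention(nn.Module):
    """Stand-in attention: projects to num_heads * head_size and back to
    emb_size, so num_heads need not divide emb_size (the tests use
    num_heads=5 with emb_size=12)."""

    def __init__(self, num_heads, emb_size, head_size, dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.qkv = nn.Linear(emb_size, 3 * num_heads * head_size)
        self.proj = nn.Linear(num_heads * head_size, emb_size)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, self.head_size)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)  # each (batch, heads, seq_len, head_size)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.head_size)
        if mask is not None:
            # convention from the tests: 1 = may attend, 0 = masked out
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = self.drop(scores.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(batch, seq_len, -1)
        return self.proj(out)


class Decoder(nn.Module):
    """Single decoder block: masked self-attention plus feed-forward with
    post-norm residuals. max_seq_len is accepted for interface parity; the
    positional embeddings mentioned in the commit are not reproduced here."""

    def __init__(self, num_heads, emb_size, head_size, max_seq_len, dropout):
        super().__init__()
        self._heads = MultiHeadAttention(num_heads, emb_size, head_size, dropout)
        self._ff = nn.Sequential(
            nn.Linear(emb_size, 4 * emb_size),
            nn.ReLU(),
            nn.Linear(4 * emb_size, emb_size),
            nn.Dropout(dropout),
        )
        self._norm1 = nn.LayerNorm(emb_size)
        self._norm2 = nn.LayerNorm(emb_size)

    def forward(self, x, mask=None):
        x = self._norm1(x + self._heads(x, mask))  # residual + post-norm
        x = self._norm2(x + self._ff(x))
        return x

Post-norm (normalizing after the residual add) follows the original transformer layout; the project may well use pre-norm instead, and its _ff is a dedicated FeedForward class rather than an nn.Sequential.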

import pytest
import torch
from torch import nn

from simple_llm.transformer.decoder import Decoder
from simple_llm.transformer.feed_forward import FeedForward
from simple_llm.transformer.multi_head_attention import MultiHeadAttention


class TestDecoder:
    @pytest.fixture
    def sample_decoder(self):
        """Fixture providing a decoder under test."""
        return Decoder(
            num_heads=4,
            emb_size=64,
            head_size=16,
            max_seq_len=128,
            dropout=0.1,
        )

    def test_initialization(self, sample_decoder):
        """Test that the sublayers are initialized correctly."""
        assert isinstance(sample_decoder._heads, MultiHeadAttention)
        assert isinstance(sample_decoder._ff, FeedForward)
        assert isinstance(sample_decoder._norm1, nn.LayerNorm)
        assert isinstance(sample_decoder._norm2, nn.LayerNorm)
        assert sample_decoder._norm1.normalized_shape == (64,)
        assert sample_decoder._norm2.normalized_shape == (64,)

    def test_forward_shapes(self, sample_decoder):
        """Test that input dimensions are preserved."""
        test_cases = [
            (1, 10, 64),  # batch=1, seq_len=10
            (4, 25, 64),  # batch=4, seq_len=25
            (8, 50, 64),  # batch=8, seq_len=50
        ]
        for batch, seq_len, emb_size in test_cases:
            x = torch.randn(batch, seq_len, emb_size)
            output = sample_decoder(x)
            assert output.shape == (batch, seq_len, emb_size)
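        # Shape preservation is expected by construction: both sublayers map
        # (batch, seq_len, emb_size) -> (batch, seq_len, emb_size), which the
        # residual connections require anyway.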

    def test_masking(self, sample_decoder):
        """Test attention mask behavior."""
        x = torch.randn(2, 10, 64)
        mask = torch.tril(torch.ones(10, 10))  # lower-triangular (causal) mask
        output = sample_decoder(x, mask)
        assert not torch.isnan(output).any()
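        # The mask convention assumed here is 1 = "may attend", 0 = "masked
        # out". If masking were applied incorrectly (e.g. a fully masked row),
        # the softmax over all -inf scores would produce NaNs, which the
        # assertion above would catch.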

    def test_sample_case(self):
        """Test against the sample case from the problem statement."""
        torch.manual_seed(0)
        decoder = Decoder(
            num_heads=5,
            emb_size=12,
            head_size=8,
            max_seq_len=20,
            dropout=0.0,
        )
        # Zero-initialize all weights for predictable output
        for p in decoder.parameters():
            nn.init.zeros_(p)
        x = torch.zeros(1, 12, 12)
        output = decoder(x)
        # Check the output shape
        assert output.shape == (1, 12, 12)
        # Check that the output is close to zero (not an exact value)
        assert torch.allclose(output, torch.zeros_like(output), atol=1e-6), \
            "Output should be close to zero with zero-initialized weights"

    def test_gradient_flow(self, sample_decoder):
        """Test that gradients flow back to the input."""
        x = torch.randn(2, 15, 64, requires_grad=True)
        output = sample_decoder(x)
        loss = output.sum()
        loss.backward()
        assert x.grad is not None
        assert not torch.isnan(x.grad).any()


if __name__ == "__main__":
    pytest.main(["-v", __file__])