diff --git a/tests/test_head_attention.py b/tests/test_head_attention.py
new file mode 100644
index 0000000..f3d0668
--- /dev/null
+++ b/tests/test_head_attention.py
@@ -0,0 +1,72 @@
+import torch
+import pytest
+from packaging import version
+from simple_llm.transformer.head_attention import HeadAttention
+
+# Check the PyTorch version for compatibility; parse the versions instead of
+# comparing raw strings, which breaks for cases like '1.10.0' < '1.2.0'.
+def check_pytorch_version():
+    print(f"\nPyTorch version: {torch.__version__}")
+    if version.parse(torch.__version__) < version.parse("1.2.0"):
+        pytest.skip("This test requires PyTorch 1.2.0 or higher")
+
+@pytest.fixture(autouse=True)
+def setup():
+    check_pytorch_version()
+
+@pytest.fixture
+def attention():
+    return HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
+
+def test_attention_initialization(attention):
+    assert attention._emb_size == 64
+    assert attention._head_size == 32
+    assert attention._max_seq_len == 128
+    assert isinstance(attention._k, torch.nn.Linear)
+    assert isinstance(attention._q, torch.nn.Linear)
+    assert isinstance(attention._v, torch.nn.Linear)
+
+def test_attention_forward_shape(attention):
+    batch_size = 4
+    seq_len = 32
+    x = torch.randn(batch_size, seq_len, 64)  # (B, T, emb_size)
+    output = attention(x)
+    assert output.shape == (batch_size, seq_len, 32)  # (B, T, head_size)
+
+def test_attention_mask():
+    attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
+    # The mask must be lower-triangular so a token cannot attend to future positions
+    mask = attention._tril_mask
+    assert mask.shape == (128, 128)
+    assert torch.all(mask == torch.tril(torch.ones(128, 128)).bool())
+
+def test_attention_causal_property(attention):
+    """Test the causal attention property"""
+    print("\nTesting causal attention...")
+    seq_len = 10
+    x = torch.randn(1, seq_len, 64)
+
+    # Project the input into query and key spaces
+    q = attention._q(x)  # (1, seq_len, head_size)
+    k = attention._k(x)  # (1, seq_len, head_size)
+
+    # Compute scaled dot-product attention scores
+    scores = q @ k.transpose(-2, -1) / (attention._head_size ** 0.5)
+
+    # Apply the causal mask (forward() is expected to do the same internally)
+    mask = attention._tril_mask[:seq_len, :seq_len]
+    masked_scores = scores.masked_fill(~mask, float('-inf'))
+
+    # Every position above the diagonal (a future token) must be masked out
+    for i in range(seq_len):
+        for j in range(i + 1, seq_len):
+            assert masked_scores[0, i, j] == float('-inf'), f"Position {i},{j} not masked"
+    print("Causal attention test passed")
+
+def test_attention_sequence_length_limit(attention):
+    with pytest.raises(ValueError):
+        x = torch.randn(1, 129, 64)  # Exceeds max_seq_len=128
+        attention(x)
+
+if __name__ == "__main__":
+    pytest.main(["-v", __file__])
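
For context, a minimal sketch of a `HeadAttention` implementation that would satisfy these tests is shown below. It is inferred from the assertions above (the attribute names `_emb_size`, `_head_size`, `_max_seq_len`, `_k`, `_q`, `_v`, `_tril_mask`, the `1/sqrt(head_size)` scaling, and the `ValueError` on over-long inputs), not taken from the actual `simple_llm` source; details such as `bias=False` on the projections are assumptions.

```python
# Hypothetical sketch only: reconstructed from the test assertions, not the
# real simple_llm.transformer.head_attention module.
import torch
import torch.nn as nn
import torch.nn.functional as F


class HeadAttention(nn.Module):
    """Single causal self-attention head (sketch)."""

    def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
        super().__init__()
        self._emb_size = emb_size
        self._head_size = head_size
        self._max_seq_len = max_seq_len
        # Key/query/value projections; bias=False is an assumption, the tests
        # only check that these are nn.Linear instances.
        self._k = nn.Linear(emb_size, head_size, bias=False)
        self._q = nn.Linear(emb_size, head_size, bias=False)
        self._v = nn.Linear(emb_size, head_size, bias=False)
        # Lower-triangular boolean mask that blocks attention to future tokens.
        self.register_buffer(
            "_tril_mask", torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, seq_len, _ = x.shape
        if seq_len > self._max_seq_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {self._max_seq_len}"
            )
        q = self._q(x)  # (B, T, head_size)
        k = self._k(x)  # (B, T, head_size)
        v = self._v(x)  # (B, T, head_size)
        # Scaled dot-product scores with the causal mask applied.
        scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)  # (B, T, T)
        scores = scores.masked_fill(
            ~self._tril_mask[:seq_len, :seq_len], float("-inf")
        )
        weights = F.softmax(scores, dim=-1)
        return weights @ v  # (B, T, head_size)
```

Under these assumptions, `test_attention_causal_property` mirrors exactly what `forward()` does before the softmax, so the masked-score check exercises the same causal constraint the module enforces internally.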