# simple-llm/tests/test_head_attention.py

import torch
import pytest
from packaging import version
from simple_llm.transformer.head_attention import HeadAttention


# Check the PyTorch version for compatibility
def check_pytorch_version():
    print(f"\nPyTorch version: {torch.__version__}")
    # Parse before comparing: a plain string comparison would wrongly
    # rank e.g. "1.10.0" below "1.2.0"
    if version.parse(torch.__version__) < version.parse("1.2.0"):
        pytest.skip("This test requires PyTorch 1.2.0 or higher")


@pytest.fixture(autouse=True)
def setup():
    check_pytorch_version()


@pytest.fixture
def attention():
    return HeadAttention(emb_size=64, head_size=32, max_seq_len=128)


def test_attention_initialization(attention):
    assert attention._emb_size == 64
    assert attention._head_size == 32
    assert attention._max_seq_len == 128
    assert isinstance(attention._k, torch.nn.Linear)
    assert isinstance(attention._q, torch.nn.Linear)
    assert isinstance(attention._v, torch.nn.Linear)


def test_attention_forward_shape(attention):
    batch_size = 4
    seq_len = 32
    x = torch.randn(batch_size, seq_len, 64)  # (B, T, emb_size)
    output = attention(x)
    assert output.shape == (batch_size, seq_len, 32)  # (B, T, head_size)


def test_attention_mask():
    attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
    # Check that the mask is lower-triangular
    mask = attention._tril_mask
    assert mask.shape == (128, 128)
    assert torch.all(mask == torch.tril(torch.ones(128, 128)).bool())


def test_attention_causal_property(attention):
    """Test the causal attention property."""
    print("\nTesting causal attention...")
    seq_len = 10
    x = torch.randn(1, seq_len, 64)
    # Project the input into queries and keys
    q = attention._q(x)  # (1, seq_len, head_size)
    k = attention._k(x)  # (1, seq_len, head_size)
    # Compute scaled dot-product attention scores
    scores = q @ k.transpose(-2, -1) / (attention._head_size ** 0.5)
    # Apply the causal mask (forward() is expected to do the same internally)
    mask = attention._tril_mask[:seq_len, :seq_len]
    masked_scores = scores.masked_fill(~mask, float('-inf'))
    # Verify that future tokens are masked out
    for i in range(seq_len):
        for j in range(i + 1, seq_len):
            assert masked_scores[0, i, j] == float('-inf'), f"Position {i},{j} not masked"
    print("Causal attention test passed")


def test_attention_sequence_length_limit(attention):
    # Exceeding max_seq_len should raise a ValueError
    with pytest.raises(ValueError):
        x = torch.randn(1, 129, 64)  # one token past max_seq_len=128
        attention(x)


if __name__ == "__main__":
    pytest.main(["-v", __file__])
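

# For reference only: a minimal sketch of the module these tests assume,
# reconstructed from the attributes the tests touch (_emb_size, _head_size,
# _max_seq_len, _q, _k, _v, _tril_mask, and the ValueError on overlong
# inputs). It is an assumption, not the actual implementation in
# simple_llm.transformer.head_attention; details such as bias=False are
# guesses. Named differently so it cannot shadow the imported class.
class _ReferenceHeadAttention(torch.nn.Module):
    def __init__(self, emb_size: int, head_size: int, max_seq_len: int):
        super().__init__()
        self._emb_size = emb_size
        self._head_size = head_size
        self._max_seq_len = max_seq_len
        self._q = torch.nn.Linear(emb_size, head_size, bias=False)
        self._k = torch.nn.Linear(emb_size, head_size, bias=False)
        self._v = torch.nn.Linear(emb_size, head_size, bias=False)
        # Lower-triangular causal mask, stored as a non-trainable buffer
        self.register_buffer(
            "_tril_mask", torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, seq_len, _ = x.shape
        if seq_len > self._max_seq_len:
            raise ValueError(f"Sequence length {seq_len} exceeds {self._max_seq_len}")
        q, k, v = self._q(x), self._k(x), self._v(x)
        # Scaled dot-product scores, causally masked, then softmax-weighted values
        scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
        mask = self._tril_mask[:seq_len, :seq_len]
        scores = scores.masked_fill(~mask, float("-inf"))
        return torch.softmax(scores, dim=-1) @ v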