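"""Unit tests for the HeadAttention module (single-head causal self-attention)."""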
import pytest
import torch

from simple_llm.transformer.head_attention import HeadAttention


# Check the PyTorch version for compatibility
def check_pytorch_version():
    print(f"\nPyTorch version: {torch.__version__}")
    # Compare numeric version components; a plain string comparison would
    # wrongly treat e.g. '1.10.0' as older than '1.2.0'.
    major, minor = (int(part) for part in torch.__version__.split("+")[0].split(".")[:2])
    if (major, minor) < (1, 2):
        pytest.skip("This test requires PyTorch 1.2.0 or higher")


@pytest.fixture(autouse=True)
def setup():
    check_pytorch_version()


@pytest.fixture
def attention():
    return HeadAttention(emb_size=64, head_size=32, max_seq_len=128)


def test_attention_initialization(attention):
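    """The constructor should store the configured sizes and create Q/K/V linear projections."""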
    assert attention._emb_size == 64
    assert attention._head_size == 32
    assert attention._max_seq_len == 128
    assert isinstance(attention._k, torch.nn.Linear)
    assert isinstance(attention._q, torch.nn.Linear)
    assert isinstance(attention._v, torch.nn.Linear)


def test_attention_forward_shape(attention):
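    """A (B, T, emb_size) input should produce a (B, T, head_size) output."""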
    batch_size = 4
    seq_len = 32
    x = torch.randn(batch_size, seq_len, 64)  # (B, T, emb_size)
    output = attention(x)
    assert output.shape == (batch_size, seq_len, 32)  # (B, T, head_size)


def test_attention_mask():
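    """The attention mask should be a (max_seq_len, max_seq_len) lower-triangular boolean matrix."""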
    attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
    # Check that the mask is lower triangular
    mask = attention._tril_mask
    assert mask.shape == (128, 128)
    assert torch.all(mask == torch.tril(torch.ones(128, 128)).bool())


def test_attention_causal_property(attention):
    """Test the causal attention property."""
    print("\nTesting causal attention...")
    seq_len = 10
    x = torch.randn(1, seq_len, 64)

    # Get the query and key projections
    q = attention._q(x)  # (1, seq_len, head_size)
    k = attention._k(x)  # (1, seq_len, head_size)

    # Compute scaled dot-product attention scores
    scores = q @ k.transpose(-2, -1) / (attention._head_size ** 0.5)

    # Apply the mask (this is expected to happen inside forward)
    mask = attention._tril_mask[:seq_len, :seq_len]
    masked_scores = scores.masked_fill(~mask, float('-inf'))

    # Check that future tokens are masked
    for i in range(seq_len):
        for j in range(i + 1, seq_len):
            assert masked_scores[0, i, j] == float('-inf'), f"Position {i},{j} not masked"
    print("Causal attention test passed")


def test_attention_sequence_length_limit(attention):
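    """Inputs longer than max_seq_len should raise a ValueError."""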
    with pytest.raises(ValueError):
        x = torch.randn(1, 129, 64)  # exceeds max_seq_len
        attention(x)


if __name__ == "__main__":
    pytest.main(["-v", __file__])