import torch
import pytest

from simple_llm.transformer.multi_head_attention import MultiHeadAttention


@pytest.fixture
def sample_input():
    """Fixture with test input data."""
    batch_size = 2
    seq_len = 10
    emb_size = 64
    return torch.randn(batch_size, seq_len, emb_size)


def test_initialization():
    """Test initialization with valid parameters."""
    mha = MultiHeadAttention(
        num_heads=8,
        emb_size=64,
        head_size=32,
        max_seq_len=100,
        dropout=0.1
    )

    assert len(mha._heads) == 8
    assert mha._layer.in_features == 8 * 32
    assert mha._layer.out_features == 64
    assert mha._dropout.p == 0.1


def test_forward_pass(sample_input):
    """Test that the forward pass preserves the input dimensions."""
    mha = MultiHeadAttention(
        num_heads=4,
        emb_size=64,
        head_size=16,
        max_seq_len=50
    )

    output = mha(sample_input)
    assert output.shape == sample_input.shape


def test_dropout_effect(sample_input):
    """Test the effect of dropout on the output."""
    mha_with_dropout = MultiHeadAttention(
        num_heads=4,
        emb_size=64,
        head_size=16,
        max_seq_len=50,
        dropout=0.5
    )

    mha_without_dropout = MultiHeadAttention(
        num_heads=4,
        emb_size=64,
        head_size=16,
        max_seq_len=50,
        dropout=0.0
    )

    # Copy the weights so that any difference in the outputs comes from
    # dropout rather than from the two modules being initialized with
    # different random weights.
    mha_without_dropout.load_state_dict(mha_with_dropout.state_dict())

    output1 = mha_with_dropout(sample_input)
    output2 = mha_without_dropout(sample_input)
    assert not torch.allclose(output1, output2)


def test_gradient_flow(sample_input):
    """Test that backpropagation works correctly."""
    mha = MultiHeadAttention(
        num_heads=4,
        emb_size=64,
        head_size=16,
        max_seq_len=50
    )

    sample_input.requires_grad_(True)
    output = mha(sample_input)
    output.sum().backward()
    assert sample_input.grad is not None


def test_mask_support(sample_input):
    """Test mask support (should pass even without an implementation)."""
    mask = torch.ones(sample_input.shape[:2])
    mha = MultiHeadAttention(
        num_heads=4,
        emb_size=64,
        head_size=16,
        max_seq_len=50
    )

    try:
        output = mha(sample_input, mask=mask)
        assert output.shape == sample_input.shape
    except Exception as e:
        pytest.fail(f"Mask handling failed: {e}")
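

# Sketch of an additional check, complementing test_dropout_effect. It assumes
# MultiHeadAttention applies dropout only through standard nn.Dropout layers:
# in train mode the dropout mask is resampled on every call, so repeated
# forward passes should differ, while eval() disables dropout and makes the
# output deterministic. The test name is illustrative, not part of the
# original suite.
def test_dropout_train_vs_eval(sample_input):
    mha = MultiHeadAttention(
        num_heads=4,
        emb_size=64,
        head_size=16,
        max_seq_len=50,
        dropout=0.5
    )

    # Train mode: stochastic dropout, two calls should not match exactly.
    mha.train()
    assert not torch.allclose(mha(sample_input), mha(sample_input))

    # Eval mode: dropout is a no-op, repeated calls must be identical.
    mha.eval()
    assert torch.allclose(mha(sample_input), mha(sample_input))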