diff --git a/simple_llm/transformer/head_attention.py b/simple_llm/transformer/head_attention.py index ac6f061..3366f7b 100644 --- a/simple_llm/transformer/head_attention.py +++ b/simple_llm/transformer/head_attention.py @@ -1,6 +1,7 @@ import torch from torch import nn import torch.nn.functional as F +import math from math import sqrt class HeadAttention(nn.Module): @@ -45,6 +46,36 @@ class HeadAttention(nn.Module): mask = torch.tril(torch.ones(max_seq_len, max_seq_len)) self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte()) + def get_attention_weights(self, x: torch.Tensor) -> torch.Tensor: + """ + Возвращает матрицу весов внимания без умножения на V. + + Аргументы: + x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size] + + Возвращает: + torch.Tensor: Матрица весов внимания формы [batch_size, seq_len, seq_len] + + Пример: + >>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128) + >>> x = torch.randn(1, 10, 64) + >>> weights = attention.get_attention_weights(x) # [1, 10, 10] + """ + seq_len = x.shape[1] + if seq_len > self._max_seq_len: + raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}") + + # Вычисляем Q и K + k = self._k(x) # [B, T, hs] + q = self._q(x) # [B, T, hs] + + # Вычисляем scores и применяем маску + scores = q @ k.transpose(-2, -1) / math.sqrt(self._head_size) + scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf')) + + # Возвращаем нормализованные веса + return torch.softmax(scores, dim=-1) + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Прямой проход через слой внимания. diff --git a/tests/test_multi_head_attention.py b/tests/test_multi_head_attention.py new file mode 100644 index 0000000..bdba760 --- /dev/null +++ b/tests/test_multi_head_attention.py @@ -0,0 +1,90 @@ +import torch +import pytest +from simple_llm.transformer.multi_head_attention import MultiHeadAttention + +@pytest.fixture +def sample_input(): + """Фикстура с тестовыми входными данными""" + batch_size = 2 + seq_len = 10 + emb_size = 64 + return torch.randn(batch_size, seq_len, emb_size) + +def test_initialization(): + """Тест инициализации с правильными параметрами""" + mha = MultiHeadAttention( + num_heads=8, + emb_size=64, + head_size=32, + max_seq_len=100, + dropout=0.1 + ) + + assert len(mha._heads) == 8 + assert mha._layer.in_features == 8 * 32 + assert mha._layer.out_features == 64 + assert mha._dropout.p == 0.1 + +def test_forward_pass(sample_input): + """Тест прямого прохода с сохранением размерности""" + mha = MultiHeadAttention( + num_heads=4, + emb_size=64, + head_size=16, + max_seq_len=50 + ) + + output = mha(sample_input) + assert output.shape == sample_input.shape + +def test_dropout_effect(sample_input): + """Тест влияния dropout на выход""" + mha_with_dropout = MultiHeadAttention( + num_heads=4, + emb_size=64, + head_size=16, + max_seq_len=50, + dropout=0.5 + ) + + mha_without_dropout = MultiHeadAttention( + num_heads=4, + emb_size=64, + head_size=16, + max_seq_len=50, + dropout=0.0 + ) + + output1 = mha_with_dropout(sample_input) + output2 = mha_without_dropout(sample_input) + assert not torch.allclose(output1, output2) + +def test_gradient_flow(sample_input): + """Тест корректности обратного распространения""" + mha = MultiHeadAttention( + num_heads=4, + emb_size=64, + head_size=16, + max_seq_len=50 + ) + + sample_input.requires_grad_(True) + output = mha(sample_input) + output.sum().backward() + assert sample_input.grad is not None + +def test_mask_support(sample_input): + """Тест поддержки масок (должен проходить даже без реализации)""" + mask = torch.ones(sample_input.shape[:2]) + mha = MultiHeadAttention( + num_heads=4, + emb_size=64, + head_size=16, + max_seq_len=50 + ) + + try: + output = mha(sample_input, mask=mask) + assert output.shape == sample_input.shape + except Exception as e: + pytest.fail(f"Mask handling failed: {e}")