mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Добавление тестов для MultiHeadAttention + финальные правки
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
import math
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
|
|
||||||
class HeadAttention(nn.Module):
|
class HeadAttention(nn.Module):
|
||||||
@@ -45,6 +46,36 @@ class HeadAttention(nn.Module):
|
|||||||
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
|
||||||
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
|
||||||
|
|
||||||
|
def get_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Возвращает матрицу весов внимания без умножения на V.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
torch.Tensor: Матрица весов внимания формы [batch_size, seq_len, seq_len]
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
|
||||||
|
>>> x = torch.randn(1, 10, 64)
|
||||||
|
>>> weights = attention.get_attention_weights(x) # [1, 10, 10]
|
||||||
|
"""
|
||||||
|
seq_len = x.shape[1]
|
||||||
|
if seq_len > self._max_seq_len:
|
||||||
|
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
|
||||||
|
|
||||||
|
# Вычисляем Q и K
|
||||||
|
k = self._k(x) # [B, T, hs]
|
||||||
|
q = self._q(x) # [B, T, hs]
|
||||||
|
|
||||||
|
# Вычисляем scores и применяем маску
|
||||||
|
scores = q @ k.transpose(-2, -1) / math.sqrt(self._head_size)
|
||||||
|
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
|
||||||
|
|
||||||
|
# Возвращаем нормализованные веса
|
||||||
|
return torch.softmax(scores, dim=-1)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Прямой проход через слой внимания.
|
Прямой проход через слой внимания.
|
||||||
|
|||||||
90
tests/test_multi_head_attention.py
Normal file
90
tests/test_multi_head_attention.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from simple_llm.transformer.multi_head_attention import MultiHeadAttention
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_input():
|
||||||
|
"""Фикстура с тестовыми входными данными"""
|
||||||
|
batch_size = 2
|
||||||
|
seq_len = 10
|
||||||
|
emb_size = 64
|
||||||
|
return torch.randn(batch_size, seq_len, emb_size)
|
||||||
|
|
||||||
|
def test_initialization():
|
||||||
|
"""Тест инициализации с правильными параметрами"""
|
||||||
|
mha = MultiHeadAttention(
|
||||||
|
num_heads=8,
|
||||||
|
emb_size=64,
|
||||||
|
head_size=32,
|
||||||
|
max_seq_len=100,
|
||||||
|
dropout=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(mha._heads) == 8
|
||||||
|
assert mha._layer.in_features == 8 * 32
|
||||||
|
assert mha._layer.out_features == 64
|
||||||
|
assert mha._dropout.p == 0.1
|
||||||
|
|
||||||
|
def test_forward_pass(sample_input):
|
||||||
|
"""Тест прямого прохода с сохранением размерности"""
|
||||||
|
mha = MultiHeadAttention(
|
||||||
|
num_heads=4,
|
||||||
|
emb_size=64,
|
||||||
|
head_size=16,
|
||||||
|
max_seq_len=50
|
||||||
|
)
|
||||||
|
|
||||||
|
output = mha(sample_input)
|
||||||
|
assert output.shape == sample_input.shape
|
||||||
|
|
||||||
|
def test_dropout_effect(sample_input):
|
||||||
|
"""Тест влияния dropout на выход"""
|
||||||
|
mha_with_dropout = MultiHeadAttention(
|
||||||
|
num_heads=4,
|
||||||
|
emb_size=64,
|
||||||
|
head_size=16,
|
||||||
|
max_seq_len=50,
|
||||||
|
dropout=0.5
|
||||||
|
)
|
||||||
|
|
||||||
|
mha_without_dropout = MultiHeadAttention(
|
||||||
|
num_heads=4,
|
||||||
|
emb_size=64,
|
||||||
|
head_size=16,
|
||||||
|
max_seq_len=50,
|
||||||
|
dropout=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
output1 = mha_with_dropout(sample_input)
|
||||||
|
output2 = mha_without_dropout(sample_input)
|
||||||
|
assert not torch.allclose(output1, output2)
|
||||||
|
|
||||||
|
def test_gradient_flow(sample_input):
|
||||||
|
"""Тест корректности обратного распространения"""
|
||||||
|
mha = MultiHeadAttention(
|
||||||
|
num_heads=4,
|
||||||
|
emb_size=64,
|
||||||
|
head_size=16,
|
||||||
|
max_seq_len=50
|
||||||
|
)
|
||||||
|
|
||||||
|
sample_input.requires_grad_(True)
|
||||||
|
output = mha(sample_input)
|
||||||
|
output.sum().backward()
|
||||||
|
assert sample_input.grad is not None
|
||||||
|
|
||||||
|
def test_mask_support(sample_input):
|
||||||
|
"""Тест поддержки масок (должен проходить даже без реализации)"""
|
||||||
|
mask = torch.ones(sample_input.shape[:2])
|
||||||
|
mha = MultiHeadAttention(
|
||||||
|
num_heads=4,
|
||||||
|
emb_size=64,
|
||||||
|
head_size=16,
|
||||||
|
max_seq_len=50
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
output = mha(sample_input, mask=mask)
|
||||||
|
assert output.shape == sample_input.shape
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Mask handling failed: {e}")
|
||||||
Reference in New Issue
Block a user