mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
refactor(core): refactor RoPE and MultiHeadAttention, add math-rich docs, expand tests, remove unused head_attention
- refactor: улучшена и унифицирована реализация RoPE, теперь поддерживаются строгие проверки размерности входа; внесены улучшения и структурные изменения в MultiHeadAttention (более понятная логика, строгая спецификация входов/выходов) - docs: полностью переписаны docstrings для RoPE и MultiHeadAttention — включены математические формулы, ссылки на научные статьи, подробные пояснения по алгоритму, формату входных данных, ограничениям, примеры использования - test: добавлены отдельные unit-тесты для RoPE (корректность формы, ошибки на неверную размерность, сохранение нормы, backward/градиенты, работу с параметрами start_pos и батчами) - chore: удалён неиспользуемый модуль core/head_attention.py - fix: теперь выбрасывается AssertionError при неправильной размерности входа RoPE; это позволило полностью покрыть тест-кейсы на ошибки Этот коммит синхронизирует логику реализации базового внимания с современной практикой LLM, укрепляет документацию для инженеров и исследователей, а также расширяет надежность автотестирования библиотеки.
This commit is contained in:
@@ -19,7 +19,7 @@ class TestMultiHeadAttention:
|
||||
assert attention is not None
|
||||
|
||||
# Check internal attributes
|
||||
assert len(attention._heads) == num_heads
|
||||
assert attention._num_heads == num_heads
|
||||
assert attention._layer.in_features == embed_dim
|
||||
assert attention._layer.out_features == embed_dim
|
||||
|
||||
@@ -102,8 +102,10 @@ class TestMultiHeadAttention:
|
||||
|
||||
# Check that gradients are computed for learnable parameters
|
||||
assert attention._layer.weight.grad is not None
|
||||
if len(attention._heads) > 0:
|
||||
assert attention._heads[0]._q.weight.grad is not None
|
||||
# Проверяем, что также у градиентов весов q/k/v есть значения
|
||||
assert attention._q.weight.grad is not None
|
||||
assert attention._k.weight.grad is not None
|
||||
assert attention._v.weight.grad is not None
|
||||
|
||||
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
|
||||
"""Test that MultiHeadAttention works on correct device."""
|
||||
|
||||
55
llm/tests/core/test_rope.py
Normal file
55
llm/tests/core/test_rope.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
def test_rope_shapes_and_dtype():
|
||||
rope = RoPE(head_size=8, max_seq_len=32)
|
||||
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
|
||||
y = rope(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
def test_rope_raises_on_bad_ndim():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
|
||||
with pytest.raises(AssertionError):
|
||||
_ = rope(x)
|
||||
|
||||
def test_rope_preserves_norm():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 3, 7, 8)
|
||||
x_norm = x.norm(dim=-1)
|
||||
y = rope(x)
|
||||
y_norm = y.norm(dim=-1)
|
||||
# Нормы могут немного отличаться из-за float, сравниваем с допуском
|
||||
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
|
||||
|
||||
def test_rope_backward_pass():
|
||||
rope = RoPE(head_size=8, max_seq_len=16)
|
||||
x = torch.randn(2, 2, 8, 8, requires_grad=True)
|
||||
out = rope(x)
|
||||
loss = out.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
|
||||
(1, 1, 4, 8),
|
||||
(2, 4, 16, 8),
|
||||
(3, 2, 7, 8),
|
||||
])
|
||||
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
|
||||
rope = RoPE(head_size=head_size, max_seq_len=32)
|
||||
x = torch.randn(batch, num_heads, seq_len, head_size)
|
||||
y = rope(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_rope_start_pos():
|
||||
rope = RoPE(head_size=8, max_seq_len=32)
|
||||
x_full = torch.randn(1, 2, 8, 8)
|
||||
# Сравниваем участок результата для разных start_pos
|
||||
out1 = rope(x_full)
|
||||
out2 = rope(x_full, start_pos=2)
|
||||
assert not torch.allclose(out1, out2)
|
||||
# Для одинакового start_pos и x должны совпадать
|
||||
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))
|
||||
Reference in New Issue
Block a user