Files
llm-arch-research/llm/tests/core/test_rope.py

56 lines
1.9 KiB
Python
Raw Normal View History

refactor(core): refactor RoPE and MultiHeadAttention, add math-rich docs, expand tests, remove unused head_attention - refactor: улучшена и унифицирована реализация RoPE, теперь поддерживаются строгие проверки размерности входа; внесены улучшения и структурные изменения в MultiHeadAttention (более понятная логика, строгая спецификация входов/выходов) - docs: полностью переписаны docstrings для RoPE и MultiHeadAttention — включены математические формулы, ссылки на научные статьи, подробные пояснения по алгоритму, формату входных данных, ограничениям, примеры использования - test: добавлены отдельные unit-тесты для RoPE (корректность формы, ошибки на неверную размерность, сохранение нормы, backward/градиенты, работу с параметрами start_pos и батчами) - chore: удалён неиспользуемый модуль core/head_attention.py - fix: теперь выбрасывается AssertionError при неправильной размерности входа RoPE; это позволило полностью покрыть тест-кейсы на ошибки Этот коммит синхронизирует логику реализации базового внимания с современной практикой LLM, укрепляет документацию для инженеров и исследователей, а также расширяет надежность автотестирования библиотеки.
2025-10-15 10:59:56 +03:00
import torch
import pytest
from llm.core.rope import RoPE
def test_rope_shapes_and_dtype():
rope = RoPE(head_size=8, max_seq_len=32)
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
y = rope(x)
assert y.shape == x.shape
assert y.dtype == x.dtype
def test_rope_raises_on_bad_ndim():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
with pytest.raises(AssertionError):
_ = rope(x)
def test_rope_preserves_norm():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 3, 7, 8)
x_norm = x.norm(dim=-1)
y = rope(x)
y_norm = y.norm(dim=-1)
# Нормы могут немного отличаться из-за float, сравниваем с допуском
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
def test_rope_backward_pass():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 2, 8, 8, requires_grad=True)
out = rope(x)
loss = out.sum()
loss.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
(1, 1, 4, 8),
(2, 4, 16, 8),
(3, 2, 7, 8),
])
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
rope = RoPE(head_size=head_size, max_seq_len=32)
x = torch.randn(batch, num_heads, seq_len, head_size)
y = rope(x)
assert y.shape == x.shape
def test_rope_start_pos():
rope = RoPE(head_size=8, max_seq_len=32)
x_full = torch.randn(1, 2, 8, 8)
# Сравниваем участок результата для разных start_pos
out1 = rope(x_full)
out2 = rope(x_full, start_pos=2)
assert not torch.allclose(out1, out2)
# Для одинакового start_pos и x должны совпадать
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))