mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
- refactor: улучшена и унифицирована реализация RoPE, теперь поддерживаются строгие проверки размерности входа; внесены улучшения и структурные изменения в MultiHeadAttention (более понятная логика, строгая спецификация входов/выходов) - docs: полностью переписаны docstrings для RoPE и MultiHeadAttention — включены математические формулы, ссылки на научные статьи, подробные пояснения по алгоритму, формату входных данных, ограничениям, примеры использования - test: добавлены отдельные unit-тесты для RoPE (корректность формы, ошибки на неверную размерность, сохранение нормы, backward/градиенты, работу с параметрами start_pos и батчами) - chore: удалён неиспользуемый модуль core/head_attention.py - fix: теперь выбрасывается AssertionError при неправильной размерности входа RoPE; это позволило полностью покрыть тест-кейсы на ошибки Этот коммит синхронизирует логику реализации базового внимания с современной практикой LLM, укрепляет документацию для инженеров и исследователей, а также расширяет надежность автотестирования библиотеки.
56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
import torch
|
||
import pytest
|
||
from llm.core.rope import RoPE
|
||
|
||
def test_rope_shapes_and_dtype():
|
||
rope = RoPE(head_size=8, max_seq_len=32)
|
||
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
|
||
y = rope(x)
|
||
assert y.shape == x.shape
|
||
assert y.dtype == x.dtype
|
||
|
||
def test_rope_raises_on_bad_ndim():
|
||
rope = RoPE(head_size=8, max_seq_len=16)
|
||
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
|
||
with pytest.raises(AssertionError):
|
||
_ = rope(x)
|
||
|
||
def test_rope_preserves_norm():
|
||
rope = RoPE(head_size=8, max_seq_len=16)
|
||
x = torch.randn(2, 3, 7, 8)
|
||
x_norm = x.norm(dim=-1)
|
||
y = rope(x)
|
||
y_norm = y.norm(dim=-1)
|
||
# Нормы могут немного отличаться из-за float, сравниваем с допуском
|
||
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
|
||
|
||
def test_rope_backward_pass():
|
||
rope = RoPE(head_size=8, max_seq_len=16)
|
||
x = torch.randn(2, 2, 8, 8, requires_grad=True)
|
||
out = rope(x)
|
||
loss = out.sum()
|
||
loss.backward()
|
||
assert x.grad is not None
|
||
assert x.grad.shape == x.shape
|
||
|
||
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
|
||
(1, 1, 4, 8),
|
||
(2, 4, 16, 8),
|
||
(3, 2, 7, 8),
|
||
])
|
||
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
|
||
rope = RoPE(head_size=head_size, max_seq_len=32)
|
||
x = torch.randn(batch, num_heads, seq_len, head_size)
|
||
y = rope(x)
|
||
assert y.shape == x.shape
|
||
|
||
def test_rope_start_pos():
|
||
rope = RoPE(head_size=8, max_seq_len=32)
|
||
x_full = torch.randn(1, 2, 8, 8)
|
||
# Сравниваем участок результата для разных start_pos
|
||
out1 = rope(x_full)
|
||
out2 = rope(x_full, start_pos=2)
|
||
assert not torch.allclose(out1, out2)
|
||
# Для одинакового start_pos и x должны совпадать
|
||
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))
|