mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-24 05:21:16 +00:00
56 lines
1.9 KiB
Python
56 lines
1.9 KiB
Python
|
|
import torch
|
|||
|
|
import pytest
|
|||
|
|
from llm.core.rope import RoPE
|
|||
|
|
|
|||
|
|
def test_rope_shapes_and_dtype():
    """Check that applying RoPE leaves the tensor's shape and dtype unchanged."""
    rotary = RoPE(head_size=8, max_seq_len=32)
    # Layout: [batch, num_heads, seq_len, head_size].
    inputs = torch.randn(2, 4, 16, 8)
    rotated = rotary(inputs)
    assert inputs.shape == rotated.shape
    assert inputs.dtype == rotated.dtype
|
|||
|
|
|
|||
|
|
def test_rope_raises_on_bad_ndim():
    """Check that calling RoPE on a 3D tensor raises AssertionError."""
    rotary = RoPE(head_size=8, max_seq_len=16)
    # Only three dimensions: [batch, seq_len, head_size].
    bad_input = torch.randn(2, 16, 8)
    with pytest.raises(AssertionError):
        rotary(bad_input)
|
|||
|
|
|
|||
|
|
def test_rope_preserves_norm():
    """Check that the last-dimension norm is the same before and after RoPE."""
    rotary = RoPE(head_size=8, max_seq_len=16)
    inputs = torch.randn(2, 3, 7, 8)
    # Take the input norm before the call, exactly as the reference is meant
    # to describe the untouched tensor.
    norm_before = inputs.norm(dim=-1)
    norm_after = rotary(inputs).norm(dim=-1)
    # Floating-point arithmetic may shift the norms slightly, so compare
    # with a tolerance rather than exactly.
    assert torch.allclose(norm_before, norm_after, rtol=1e-5, atol=1e-7)
|
|||
|
|
|
|||
|
|
def test_rope_backward_pass():
    """Check that backpropagating through RoPE populates the input gradient."""
    rotary = RoPE(head_size=8, max_seq_len=16)
    inputs = torch.randn(2, 2, 8, 8, requires_grad=True)
    # Reduce to a scalar so backward() can be called without arguments.
    rotary(inputs).sum().backward()
    grad = inputs.grad
    assert grad is not None
    assert grad.shape == inputs.shape
|
|||
|
|
|
|||
|
|
@pytest.mark.parametrize(
    "batch,num_heads,seq_len,head_size",
    [
        (1, 1, 4, 8),
        (2, 4, 16, 8),
        (3, 2, 7, 8),
    ],
)
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
    """Check that RoPE keeps the input shape for each parametrized size."""
    rotary = RoPE(head_size=head_size, max_seq_len=32)
    shape = (batch, num_heads, seq_len, head_size)
    rotated = rotary(torch.randn(*shape))
    assert rotated.shape == shape
|
|||
|
|
|
|||
|
|
def test_rope_start_pos():
    """Check that start_pos changes the output and that repeated calls agree."""
    rotary = RoPE(head_size=8, max_seq_len=32)
    inputs = torch.randn(1, 2, 8, 8)

    # The same input with a different start_pos should give a different result.
    default_out = rotary(inputs)
    shifted_out = rotary(inputs, start_pos=2)
    assert not torch.allclose(default_out, shifted_out)

    # The same input with the same start_pos should give matching results.
    first_call = rotary(inputs, start_pos=1)
    second_call = rotary(inputs, start_pos=1)
    assert torch.allclose(first_call, second_call)
|