docs(core): expand docstrings and add unit tests for RMSNorm

- docs: expand docstrings for the RMSNorm class and its methods (motivation, formula, architecture notes, usage, references to LLaMA/PaLM/GPT)
- test: add comprehensive unit tests for RMSNorm (shape/dtype preservation, RMS scaling, gradients for input and weights, fp16, large-eps stability)

No code/API changes beyond docs and new tests.
Author: Sergey Penkovsky
Date:   2025-10-16 14:37:25 +03:00
Parent: 8018efae2a
Commit: 6efc946027

2 changed files with 113 additions and 26 deletions
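For reviewers unfamiliar with the layer, here is a minimal sketch of what the expanded docstrings describe (Zhang & Sennrich, 2019). It is an illustration, not the actual code in llm.core.rms_norm; the ones-initialized `_w` gain attribute is an assumption chosen to match what the new tests access.

    import torch
    import torch.nn as nn

    class RMSNorm(nn.Module):
        """Root-mean-square layer normalization.

        Unlike LayerNorm, RMSNorm does not subtract the mean; it only
        rescales by the RMS over the last dimension and applies a gain:

            y = x / sqrt(mean(x**2) + eps) * w
        """

        def __init__(self, dim: int, eps: float = 1e-6):
            super().__init__()
            self.eps = eps
            self._w = nn.Parameter(torch.ones(dim))  # learnable gain (assumed attribute name)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # RMS over the feature dimension, kept for broadcasting.
            rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
            return x / rms * self._w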


@@ -0,0 +1,47 @@
import torch
import pytest
from llm.core.rms_norm import RMSNorm

def test_rmsnorm_shape_preservation():
    # RMSNorm normalizes over the last dimension; the output shape must match the input.
    norm = RMSNorm(64)
    x = torch.randn(3, 5, 64)
    y = norm(x)
    assert y.shape == x.shape

def test_rmsnorm_dtype_and_device():
    # The output should keep the input's dtype and device.
    norm = RMSNorm(32)
    x = torch.randn(8, 32, device='cpu', dtype=torch.float64)
    y = norm(x)
    assert y.dtype == torch.float64
    assert y.device == x.device

def test_rmsnorm_mean_no_shift():
    # RMSNorm rescales by the root mean square without subtracting the mean.
    # With the default ones-initialized gain, the RMS of the output should
    # land near the mean of the weight vector (loose tolerance: random input).
    norm = RMSNorm(32)
    x = torch.randn(3, 128, 32)
    y = norm(x)
    rms = torch.sqrt((y ** 2).mean(dim=-1))
    w_mean = norm._w.mean().item()
    assert torch.allclose(rms.mean(), torch.tensor(w_mean), rtol=0.2, atol=0.2)

def test_rmsnorm_backward():
    # Gradients must flow to both the input and the learnable gain.
    norm = RMSNorm(16)
    x = torch.randn(2, 15, 16, requires_grad=True)
    y = norm(x)
    loss = y.sum()
    loss.backward()
    assert x.grad is not None
    assert norm._w.grad is not None

def test_rmsnorm_fp16():
    # A half-precision forward pass should preserve shape and dtype.
    norm = RMSNorm(8).half()
    x = torch.randn(2, 6, 8).half()
    y = norm(x)
    assert y.shape == x.shape
    assert y.dtype == torch.float16

def test_rmsnorm_large_eps_stability():
    # With all-zero input the denominator reduces to sqrt(eps) alone; the
    # output must stay finite (no NaN/Inf) even for a deliberately large eps.
    norm = RMSNorm(16, eps=1.0)
    x = torch.zeros(2, 5, 16)
    y = norm(x)
    assert not torch.isnan(y).any()
    assert not torch.isinf(y).any()
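
To run just these tests locally (the path below is an assumption; adjust it to wherever the new file lives in the repository):

    pytest tests/test_rms_norm.py -v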