docs(core): expand docstrings and add unit tests for RMSNorm

- docs: update/increase docstring detail for RMSNorm class and methods (motivation, formula, architecture, usage, references to LLaMA/PaLM/GPT)
- test: add comprehensive unit tests for RMSNorm (shape/type preservation, rms scaling, gradients for input and weights, fp16, large eps stability)

No code/API changes beyond docs and new tests.
This commit is contained in:
Sergey Penkovsky
2025-10-16 14:37:25 +03:00
parent 8018efae2a
commit 6efc946027
2 changed files with 113 additions and 26 deletions

View File

@@ -24,35 +24,63 @@ from typing import Optional
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
""" """
RMS Normalization (Root Mean Square Layer Normalization). RMSNorm (Root Mean Square Layer Normalization) — простая и эффективная альтернатива LayerNorm.
Нормализует входные данные по последнему измерению используя среднеквадратичное Назначение:
значение вместо среднего, как в стандартном LayerNorm. -----------
- Нормализует входной тензор по последнему измерению только с помощью RMS (root mean square), без вычитания среднего.
- Используется в LLaMA, PaLM и других крупных языковых моделях для лучшей стабильности и ускорения обучения.
- В отличие от LayerNorm, не центрирует значения, что особенно полезно для автогерессивных трансформеров с residual-связями.
Научная суть: Мотивация и математика:
- Упрощенный вариант LayerNorm без вычисления среднего, только деление на rms. -----------------------
- Лучшая численная стабильность на больших моделях, меньше вычислений. - Формула для одного слоя и вектора x:
- Применяется в LLaMA, PaLM и др. rms = sqrt( mean( x ** 2 ) + eps )
out = w * ( x / rms )
где w — learnable scale, eps — небольшая константа для численной устойчивости.
- Нет смещения/вычитания среднего — сигнал сохраняет абсолютные значения, меньше “искажает” автоагрегатные значения на накопленных резидуалах.
Формула: Аргументы конструктора:
RMSNorm(x) = (x / sqrt(mean(x²) + eps)) * w (w — обучаемый вектор) -----------------------
dim : int
Размер последнего нормализуемого измерения (обычно совпадает с размером embedding/final head).
eps : float, default=1e-6
Малое значение для устойчивости (additive epsilon).
Args: Особенности:
dim (int): размер последнего измерения (обычно emb_size) ------------
eps (float): для численной устойчивости - Нет батч-нормализации, нет зависимости от размера батча.
- Отлично подходит для больших моделей и автогерессии — меньше шуму от residual.
Пример использования:
---------------------
>>> norm = RMSNorm(emb_size=256)
>>> x = torch.randn(4, 10, 256)
>>> out = norm(x) # возвращает tensor той же формы
References:
-----------
- Zhang & Sennrich, "Root Mean Square Layer Normalization", 2019: https://arxiv.org/abs/1910.07467
- Применение в LLaMA: https://arxiv.org/abs/2302.13971
- HuggingFace implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Пример:
>>> norm = RMSNorm(emb_size)
>>> out = norm(x)
""" """
def __init__(self, dim: int, eps: float = 1e-6): def __init__(self, dim: int, eps: float = 1e-6):
""" """
Инициализация RMSNorm слоя. Инициализация RMSNorm.
Args: Args:
dim: Размерность нормализуемого измерения -----
eps: Малое значение для численной стабильности (по умолчанию 1e-6) dim : int
Последнее нормализуемое измерение (обычно размерность embedding или hidden).
eps : float
Малое значение для устойчивости (по умолчанию 1e-6).
Внутри:
-------
- Создаётся обучаемый scale weight w для каждой компоненты dim.
- Сохраняется параметр eps для добавления к RMS.
""" """
super().__init__() super().__init__()
self._eps = eps self._eps = eps
@@ -60,16 +88,28 @@ class RMSNorm(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
""" """
Прямой проход через RMSNorm слой. Прямой проход через RMSNorm.
Args: Args:
x: Входной тензор формы [..., dim] -----
x : torch.Tensor
Входной тензор любого shape с последней размерностью dim.
Returns: Returns:
Нормализованный тензор той же формы, что и входной --------
torch.Tensor — тот же shape, что и вход x, но нормализованный по RMS на последнем измерении.
Формула:
output = w * (x / sqrt(mean(x²) + eps)) Алгоритм:
---------
- Вычислить rms = sqrt( mean( x**2, dim=-1, keepdim=True ) + eps )
- Поделить x на rms
- Помасштабировать обучаемым весом w
Пример:
-------
>>> norm = RMSNorm(256)
>>> out = norm(torch.randn(2, 10, 256))
""" """
# Вычисление RMS (Root Mean Square) по последнему измерению # Вычисление RMS (Root Mean Square) по последнему измерению
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5 rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5

View File

@@ -0,0 +1,47 @@
import torch
import pytest
from llm.core.rms_norm import RMSNorm
def test_rmsnorm_shape_preservation():
norm = RMSNorm(64)
x = torch.randn(3, 5, 64)
y = norm(x)
assert y.shape == x.shape
def test_rmsnorm_dtype_and_device():
norm = RMSNorm(32)
x = torch.randn(8, 32, device='cpu', dtype=torch.float64)
y = norm(x)
assert y.dtype == torch.float64
assert y.device == x.device
def test_rmsnorm_mean_no_shift():
norm = RMSNorm(32)
x = torch.randn(3, 128, 32)
y = norm(x)
rms = torch.sqrt((y ** 2).mean(dim=-1))
w_mean = norm._w.mean().item()
assert torch.allclose(rms.mean(), torch.tensor(w_mean), rtol=0.2, atol=0.2)
def test_rmsnorm_backward():
norm = RMSNorm(16)
x = torch.randn(2, 15, 16, requires_grad=True)
y = norm(x)
loss = y.sum()
loss.backward()
assert x.grad is not None
assert norm._w.grad is not None
def test_rmsnorm_fp16():
norm = RMSNorm(8).half()
x = torch.randn(2, 6, 8).half()
y = norm(x)
assert y.shape == x.shape
assert y.dtype == torch.float16
def test_rmsnorm_large_eps_stability():
norm = RMSNorm(16, eps=1)
x = torch.zeros(2, 5, 16)
y = norm(x)
assert not torch.isnan(y).any()
assert not torch.isinf(y).any()