mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
docs(core): expand docstrings and add unit tests for RMSNorm
- docs: update/increase docstring detail for RMSNorm class and methods (motivation, formula, architecture, usage, references to LLaMA/PaLM/GPT) - test: add comprehensive unit tests for RMSNorm (shape/type preservation, rms scaling, gradients for input and weights, fp16, large eps stability) No code/API changes beyond docs and new tests.
This commit is contained in:
@@ -24,35 +24,63 @@ from typing import Optional
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
"""
|
||||
RMS Normalization (Root Mean Square Layer Normalization).
|
||||
RMSNorm (Root Mean Square Layer Normalization) — простая и эффективная альтернатива LayerNorm.
|
||||
|
||||
Нормализует входные данные по последнему измерению используя среднеквадратичное
|
||||
значение вместо среднего, как в стандартном LayerNorm.
|
||||
Назначение:
|
||||
-----------
|
||||
- Нормализует входной тензор по последнему измерению только с помощью RMS (root mean square), без вычитания среднего.
|
||||
- Используется в LLaMA, PaLM и других крупных языковых моделях для лучшей стабильности и ускорения обучения.
|
||||
- В отличие от LayerNorm, не центрирует значения, что особенно полезно для автогерессивных трансформеров с residual-связями.
|
||||
|
||||
Научная суть:
|
||||
- Упрощенный вариант LayerNorm без вычисления среднего, только деление на rms.
|
||||
- Лучшая численная стабильность на больших моделях, меньше вычислений.
|
||||
- Применяется в LLaMA, PaLM и др.
|
||||
Мотивация и математика:
|
||||
-----------------------
|
||||
- Формула для одного слоя и вектора x:
|
||||
rms = sqrt( mean( x ** 2 ) + eps )
|
||||
out = w * ( x / rms )
|
||||
где w — learnable scale, eps — небольшая константа для численной устойчивости.
|
||||
- Нет смещения/вычитания среднего — сигнал сохраняет абсолютные значения, меньше “искажает” автоагрегатные значения на накопленных резидуалах.
|
||||
|
||||
Формула:
|
||||
RMSNorm(x) = (x / sqrt(mean(x²) + eps)) * w (w — обучаемый вектор)
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
dim : int
|
||||
Размер последнего нормализуемого измерения (обычно совпадает с размером embedding/final head).
|
||||
eps : float, default=1e-6
|
||||
Малое значение для устойчивости (additive epsilon).
|
||||
|
||||
Args:
|
||||
dim (int): размер последнего измерения (обычно emb_size)
|
||||
eps (float): для численной устойчивости
|
||||
Особенности:
|
||||
------------
|
||||
- Нет батч-нормализации, нет зависимости от размера батча.
|
||||
- Отлично подходит для больших моделей и автогерессии — меньше шуму от residual.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> norm = RMSNorm(emb_size=256)
|
||||
>>> x = torch.randn(4, 10, 256)
|
||||
>>> out = norm(x) # возвращает tensor той же формы
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Zhang & Sennrich, "Root Mean Square Layer Normalization", 2019: https://arxiv.org/abs/1910.07467
|
||||
- Применение в LLaMA: https://arxiv.org/abs/2302.13971
|
||||
- HuggingFace implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
|
||||
Пример:
|
||||
>>> norm = RMSNorm(emb_size)
|
||||
>>> out = norm(x)
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
"""
|
||||
Инициализация RMSNorm слоя.
|
||||
Инициализация RMSNorm.
|
||||
|
||||
Args:
|
||||
dim: Размерность нормализуемого измерения
|
||||
eps: Малое значение для численной стабильности (по умолчанию 1e-6)
|
||||
-----
|
||||
dim : int
|
||||
Последнее нормализуемое измерение (обычно размерность embedding или hidden).
|
||||
eps : float
|
||||
Малое значение для устойчивости (по умолчанию 1e-6).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Создаётся обучаемый scale weight w для каждой компоненты dim.
|
||||
- Сохраняется параметр eps для добавления к RMS.
|
||||
"""
|
||||
super().__init__()
|
||||
self._eps = eps
|
||||
@@ -60,16 +88,28 @@ class RMSNorm(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через RMSNorm слой.
|
||||
|
||||
Прямой проход через RMSNorm.
|
||||
|
||||
Args:
|
||||
x: Входной тензор формы [..., dim]
|
||||
|
||||
-----
|
||||
x : torch.Tensor
|
||||
Входной тензор любого shape с последней размерностью dim.
|
||||
|
||||
Returns:
|
||||
Нормализованный тензор той же формы, что и входной
|
||||
|
||||
Формула:
|
||||
output = w * (x / sqrt(mean(x²) + eps))
|
||||
--------
|
||||
torch.Tensor — тот же shape, что и вход x, но нормализованный по RMS на последнем измерении.
|
||||
|
||||
Алгоритм:
|
||||
---------
|
||||
- Вычислить rms = sqrt( mean( x**2, dim=-1, keepdim=True ) + eps )
|
||||
- Поделить x на rms
|
||||
- Помасштабировать обучаемым весом w
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> norm = RMSNorm(256)
|
||||
>>> out = norm(torch.randn(2, 10, 256))
|
||||
|
||||
"""
|
||||
# Вычисление RMS (Root Mean Square) по последнему измерению
|
||||
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
||||
|
||||
47
llm/tests/core/test_rms_norm.py
Normal file
47
llm/tests/core/test_rms_norm.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.rms_norm import RMSNorm
|
||||
|
||||
def test_rmsnorm_shape_preservation():
|
||||
norm = RMSNorm(64)
|
||||
x = torch.randn(3, 5, 64)
|
||||
y = norm(x)
|
||||
assert y.shape == x.shape
|
||||
|
||||
def test_rmsnorm_dtype_and_device():
|
||||
norm = RMSNorm(32)
|
||||
x = torch.randn(8, 32, device='cpu', dtype=torch.float64)
|
||||
y = norm(x)
|
||||
assert y.dtype == torch.float64
|
||||
assert y.device == x.device
|
||||
|
||||
def test_rmsnorm_mean_no_shift():
|
||||
norm = RMSNorm(32)
|
||||
x = torch.randn(3, 128, 32)
|
||||
y = norm(x)
|
||||
rms = torch.sqrt((y ** 2).mean(dim=-1))
|
||||
w_mean = norm._w.mean().item()
|
||||
assert torch.allclose(rms.mean(), torch.tensor(w_mean), rtol=0.2, atol=0.2)
|
||||
|
||||
def test_rmsnorm_backward():
|
||||
norm = RMSNorm(16)
|
||||
x = torch.randn(2, 15, 16, requires_grad=True)
|
||||
y = norm(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert norm._w.grad is not None
|
||||
|
||||
def test_rmsnorm_fp16():
|
||||
norm = RMSNorm(8).half()
|
||||
x = torch.randn(2, 6, 8).half()
|
||||
y = norm(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == torch.float16
|
||||
|
||||
def test_rmsnorm_large_eps_stability():
|
||||
norm = RMSNorm(16, eps=1)
|
||||
x = torch.zeros(2, 5, 16)
|
||||
y = norm(x)
|
||||
assert not torch.isnan(y).any()
|
||||
assert not torch.isinf(y).any()
|
||||
Reference in New Issue
Block a user