mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
docs(core): expand docstrings and add unit tests for RMSNorm
- docs: update/increase docstring detail for RMSNorm class and methods (motivation, formula, architecture, usage, references to LLaMA/PaLM/GPT) - test: add comprehensive unit tests for RMSNorm (shape/type preservation, rms scaling, gradients for input and weights, fp16, large eps stability) No code/API changes beyond docs and new tests.
This commit is contained in:
@@ -24,35 +24,63 @@ from typing import Optional
|
|||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
"""
|
"""
|
||||||
RMS Normalization (Root Mean Square Layer Normalization).
|
RMSNorm (Root Mean Square Layer Normalization) — простая и эффективная альтернатива LayerNorm.
|
||||||
|
|
||||||
Нормализует входные данные по последнему измерению используя среднеквадратичное
|
Назначение:
|
||||||
значение вместо среднего, как в стандартном LayerNorm.
|
-----------
|
||||||
|
- Нормализует входной тензор по последнему измерению только с помощью RMS (root mean square), без вычитания среднего.
|
||||||
|
- Используется в LLaMA, PaLM и других крупных языковых моделях для лучшей стабильности и ускорения обучения.
|
||||||
|
- В отличие от LayerNorm, не центрирует значения, что особенно полезно для автогерессивных трансформеров с residual-связями.
|
||||||
|
|
||||||
Научная суть:
|
Мотивация и математика:
|
||||||
- Упрощенный вариант LayerNorm без вычисления среднего, только деление на rms.
|
-----------------------
|
||||||
- Лучшая численная стабильность на больших моделях, меньше вычислений.
|
- Формула для одного слоя и вектора x:
|
||||||
- Применяется в LLaMA, PaLM и др.
|
rms = sqrt( mean( x ** 2 ) + eps )
|
||||||
|
out = w * ( x / rms )
|
||||||
|
где w — learnable scale, eps — небольшая константа для численной устойчивости.
|
||||||
|
- Нет смещения/вычитания среднего — сигнал сохраняет абсолютные значения, меньше “искажает” автоагрегатные значения на накопленных резидуалах.
|
||||||
|
|
||||||
Формула:
|
Аргументы конструктора:
|
||||||
RMSNorm(x) = (x / sqrt(mean(x²) + eps)) * w (w — обучаемый вектор)
|
-----------------------
|
||||||
|
dim : int
|
||||||
|
Размер последнего нормализуемого измерения (обычно совпадает с размером embedding/final head).
|
||||||
|
eps : float, default=1e-6
|
||||||
|
Малое значение для устойчивости (additive epsilon).
|
||||||
|
|
||||||
Args:
|
Особенности:
|
||||||
dim (int): размер последнего измерения (обычно emb_size)
|
------------
|
||||||
eps (float): для численной устойчивости
|
- Нет батч-нормализации, нет зависимости от размера батча.
|
||||||
|
- Отлично подходит для больших моделей и автогерессии — меньше шуму от residual.
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> norm = RMSNorm(emb_size=256)
|
||||||
|
>>> x = torch.randn(4, 10, 256)
|
||||||
|
>>> out = norm(x) # возвращает tensor той же формы
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Zhang & Sennrich, "Root Mean Square Layer Normalization", 2019: https://arxiv.org/abs/1910.07467
|
||||||
|
- Применение в LLaMA: https://arxiv.org/abs/2302.13971
|
||||||
|
- HuggingFace implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||||
|
|
||||||
Пример:
|
|
||||||
>>> norm = RMSNorm(emb_size)
|
|
||||||
>>> out = norm(x)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, dim: int, eps: float = 1e-6):
|
def __init__(self, dim: int, eps: float = 1e-6):
|
||||||
"""
|
"""
|
||||||
Инициализация RMSNorm слоя.
|
Инициализация RMSNorm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dim: Размерность нормализуемого измерения
|
-----
|
||||||
eps: Малое значение для численной стабильности (по умолчанию 1e-6)
|
dim : int
|
||||||
|
Последнее нормализуемое измерение (обычно размерность embedding или hidden).
|
||||||
|
eps : float
|
||||||
|
Малое значение для устойчивости (по умолчанию 1e-6).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Создаётся обучаемый scale weight w для каждой компоненты dim.
|
||||||
|
- Сохраняется параметр eps для добавления к RMS.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._eps = eps
|
self._eps = eps
|
||||||
@@ -60,16 +88,28 @@ class RMSNorm(nn.Module):
|
|||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Прямой проход через RMSNorm слой.
|
Прямой проход через RMSNorm.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Входной тензор формы [..., dim]
|
-----
|
||||||
|
x : torch.Tensor
|
||||||
|
Входной тензор любого shape с последней размерностью dim.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Нормализованный тензор той же формы, что и входной
|
--------
|
||||||
|
torch.Tensor — тот же shape, что и вход x, но нормализованный по RMS на последнем измерении.
|
||||||
|
|
||||||
|
Алгоритм:
|
||||||
|
---------
|
||||||
|
- Вычислить rms = sqrt( mean( x**2, dim=-1, keepdim=True ) + eps )
|
||||||
|
- Поделить x на rms
|
||||||
|
- Помасштабировать обучаемым весом w
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
-------
|
||||||
|
>>> norm = RMSNorm(256)
|
||||||
|
>>> out = norm(torch.randn(2, 10, 256))
|
||||||
|
|
||||||
Формула:
|
|
||||||
output = w * (x / sqrt(mean(x²) + eps))
|
|
||||||
"""
|
"""
|
||||||
# Вычисление RMS (Root Mean Square) по последнему измерению
|
# Вычисление RMS (Root Mean Square) по последнему измерению
|
||||||
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
|
||||||
|
|||||||
47
llm/tests/core/test_rms_norm.py
Normal file
47
llm/tests/core/test_rms_norm.py
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.rms_norm import RMSNorm
|
||||||
|
|
||||||
|
def test_rmsnorm_shape_preservation():
|
||||||
|
norm = RMSNorm(64)
|
||||||
|
x = torch.randn(3, 5, 64)
|
||||||
|
y = norm(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
|
||||||
|
def test_rmsnorm_dtype_and_device():
|
||||||
|
norm = RMSNorm(32)
|
||||||
|
x = torch.randn(8, 32, device='cpu', dtype=torch.float64)
|
||||||
|
y = norm(x)
|
||||||
|
assert y.dtype == torch.float64
|
||||||
|
assert y.device == x.device
|
||||||
|
|
||||||
|
def test_rmsnorm_mean_no_shift():
|
||||||
|
norm = RMSNorm(32)
|
||||||
|
x = torch.randn(3, 128, 32)
|
||||||
|
y = norm(x)
|
||||||
|
rms = torch.sqrt((y ** 2).mean(dim=-1))
|
||||||
|
w_mean = norm._w.mean().item()
|
||||||
|
assert torch.allclose(rms.mean(), torch.tensor(w_mean), rtol=0.2, atol=0.2)
|
||||||
|
|
||||||
|
def test_rmsnorm_backward():
|
||||||
|
norm = RMSNorm(16)
|
||||||
|
x = torch.randn(2, 15, 16, requires_grad=True)
|
||||||
|
y = norm(x)
|
||||||
|
loss = y.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert norm._w.grad is not None
|
||||||
|
|
||||||
|
def test_rmsnorm_fp16():
|
||||||
|
norm = RMSNorm(8).half()
|
||||||
|
x = torch.randn(2, 6, 8).half()
|
||||||
|
y = norm(x)
|
||||||
|
assert y.shape == x.shape
|
||||||
|
assert y.dtype == torch.float16
|
||||||
|
|
||||||
|
def test_rmsnorm_large_eps_stability():
|
||||||
|
norm = RMSNorm(16, eps=1)
|
||||||
|
x = torch.zeros(2, 5, 16)
|
||||||
|
y = norm(x)
|
||||||
|
assert not torch.isnan(y).any()
|
||||||
|
assert not torch.isinf(y).any()
|
||||||
Reference in New Issue
Block a user