diff --git a/llm/src/llm/core/rms_norm.py b/llm/src/llm/core/rms_norm.py index 6b2a5fe..52a4a53 100644 --- a/llm/src/llm/core/rms_norm.py +++ b/llm/src/llm/core/rms_norm.py @@ -24,35 +24,63 @@ from typing import Optional class RMSNorm(nn.Module): """ - RMS Normalization (Root Mean Square Layer Normalization). + RMSNorm (Root Mean Square Layer Normalization) — простая и эффективная альтернатива LayerNorm. - Нормализует входные данные по последнему измерению используя среднеквадратичное - значение вместо среднего, как в стандартном LayerNorm. + Назначение: + ----------- + - Нормализует входной тензор по последнему измерению только с помощью RMS (root mean square), без вычитания среднего. + - Используется в LLaMA, PaLM и других крупных языковых моделях для лучшей стабильности и ускорения обучения. + - В отличие от LayerNorm, не центрирует значения, что особенно полезно для автогерессивных трансформеров с residual-связями. - Научная суть: - - Упрощенный вариант LayerNorm без вычисления среднего, только деление на rms. - - Лучшая численная стабильность на больших моделях, меньше вычислений. - - Применяется в LLaMA, PaLM и др. + Мотивация и математика: + ----------------------- + - Формула для одного слоя и вектора x: + rms = sqrt( mean( x ** 2 ) + eps ) + out = w * ( x / rms ) + где w — learnable scale, eps — небольшая константа для численной устойчивости. + - Нет смещения/вычитания среднего — сигнал сохраняет абсолютные значения, меньше “искажает” автоагрегатные значения на накопленных резидуалах. - Формула: - RMSNorm(x) = (x / sqrt(mean(x²) + eps)) * w (w — обучаемый вектор) + Аргументы конструктора: + ----------------------- + dim : int + Размер последнего нормализуемого измерения (обычно совпадает с размером embedding/final head). + eps : float, default=1e-6 + Малое значение для устойчивости (additive epsilon). - Args: - dim (int): размер последнего измерения (обычно emb_size) - eps (float): для численной устойчивости + Особенности: + ------------ + - Нет батч-нормализации, нет зависимости от размера батча. + - Отлично подходит для больших моделей и автогерессии — меньше шуму от residual. + + Пример использования: + --------------------- + >>> norm = RMSNorm(emb_size=256) + >>> x = torch.randn(4, 10, 256) + >>> out = norm(x) # возвращает tensor той же формы + + References: + ----------- + - Zhang & Sennrich, "Root Mean Square Layer Normalization", 2019: https://arxiv.org/abs/1910.07467 + - Применение в LLaMA: https://arxiv.org/abs/2302.13971 + - HuggingFace implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py - Пример: - >>> norm = RMSNorm(emb_size) - >>> out = norm(x) """ def __init__(self, dim: int, eps: float = 1e-6): """ - Инициализация RMSNorm слоя. + Инициализация RMSNorm. Args: - dim: Размерность нормализуемого измерения - eps: Малое значение для численной стабильности (по умолчанию 1e-6) + ----- + dim : int + Последнее нормализуемое измерение (обычно размерность embedding или hidden). + eps : float + Малое значение для устойчивости (по умолчанию 1e-6). + + Внутри: + ------- + - Создаётся обучаемый scale weight w для каждой компоненты dim. + - Сохраняется параметр eps для добавления к RMS. """ super().__init__() self._eps = eps @@ -60,16 +88,28 @@ class RMSNorm(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Прямой проход через RMSNorm слой. - + Прямой проход через RMSNorm. + Args: - x: Входной тензор формы [..., dim] - + ----- + x : torch.Tensor + Входной тензор любого shape с последней размерностью dim. + Returns: - Нормализованный тензор той же формы, что и входной - - Формула: - output = w * (x / sqrt(mean(x²) + eps)) + -------- + torch.Tensor — тот же shape, что и вход x, но нормализованный по RMS на последнем измерении. + + Алгоритм: + --------- + - Вычислить rms = sqrt( mean( x**2, dim=-1, keepdim=True ) + eps ) + - Поделить x на rms + - Помасштабировать обучаемым весом w + + Пример: + ------- + >>> norm = RMSNorm(256) + >>> out = norm(torch.randn(2, 10, 256)) + """ # Вычисление RMS (Root Mean Square) по последнему измерению rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5 diff --git a/llm/tests/core/test_rms_norm.py b/llm/tests/core/test_rms_norm.py new file mode 100644 index 0000000..55f47f0 --- /dev/null +++ b/llm/tests/core/test_rms_norm.py @@ -0,0 +1,47 @@ +import torch +import pytest +from llm.core.rms_norm import RMSNorm + +def test_rmsnorm_shape_preservation(): + norm = RMSNorm(64) + x = torch.randn(3, 5, 64) + y = norm(x) + assert y.shape == x.shape + +def test_rmsnorm_dtype_and_device(): + norm = RMSNorm(32) + x = torch.randn(8, 32, device='cpu', dtype=torch.float64) + y = norm(x) + assert y.dtype == torch.float64 + assert y.device == x.device + +def test_rmsnorm_mean_no_shift(): + norm = RMSNorm(32) + x = torch.randn(3, 128, 32) + y = norm(x) + rms = torch.sqrt((y ** 2).mean(dim=-1)) + w_mean = norm._w.mean().item() + assert torch.allclose(rms.mean(), torch.tensor(w_mean), rtol=0.2, atol=0.2) + +def test_rmsnorm_backward(): + norm = RMSNorm(16) + x = torch.randn(2, 15, 16, requires_grad=True) + y = norm(x) + loss = y.sum() + loss.backward() + assert x.grad is not None + assert norm._w.grad is not None + +def test_rmsnorm_fp16(): + norm = RMSNorm(8).half() + x = torch.randn(2, 6, 8).half() + y = norm(x) + assert y.shape == x.shape + assert y.dtype == torch.float16 + +def test_rmsnorm_large_eps_stability(): + norm = RMSNorm(16, eps=1) + x = torch.zeros(2, 5, 16) + y = norm(x) + assert not torch.isnan(y).any() + assert not torch.isinf(y).any()