From 516f9580fb26d7bca42b23e1bb43f970a2352fdf Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Thu, 16 Oct 2025 15:09:09 +0300 Subject: [PATCH] docs(core): add docstrings and unit tests for SwiGLU block - docs: rewrite and expand docstrings for SwiGLU class and forward method (motivation, math, architecture, usage, references to LLaMA/Mistral/PaLM) - test: add unit tests for SwiGLU (shape, dtype, gradients, output range, fp16 support, reproducibility) - strictly doc/tests, no logic or API changes This improves transparency and reliability for gated FFN blocks in transformer architectures. --- llm/src/llm/core/swi_glu.py | 79 +++++++++++++++++++++++----------- llm/tests/core/test_swi_glu.py | 39 +++++++++++++++++ 2 files changed, 93 insertions(+), 25 deletions(-) create mode 100644 llm/tests/core/test_swi_glu.py diff --git a/llm/src/llm/core/swi_glu.py b/llm/src/llm/core/swi_glu.py index d0820b0..93e9f64 100644 --- a/llm/src/llm/core/swi_glu.py +++ b/llm/src/llm/core/swi_glu.py @@ -24,31 +24,55 @@ from .silu import SiLU class SwiGLU(nn.Module): """ - SwiGLU (Swish-Gated Linear Unit) — современная нелинейность для архитектур LLM (LLaMA, PaLM). + SwiGLU (Swish-Gated Linear Unit) — эффективная feed-forward нелинейность для трансформеров (LLAMA, PaLM, Mistral). - Реализация SwiGLU активационной функции. + Назначение: + ----------- + - Улучшает классический блок FeedForward (FFN) в трансформерах за счёт \"gating\" (механизма управления информационным потоком). + - Использует нелинейность SiLU (Swish) вместо ReLU или GELU, повышая capacity блока. + - Является дефолтом во всех современных LLM, начиная с PaLM, LLaMA и Mistral. - Состоит из трех линейных слоев и активации SiLU: - 1. Gate слой + SiLU активация - 2. Up слой (линейное преобразование) - 3. Element-wise multiplication gate и up - 4. Down слой (линейная проекция) + Формула и математика: + --------------------- + Пусть x — вход, then: - Научная суть: - - Сохраняет преимущества GLU (раздельные гейтом и телом) + мощность Swish/SiLU активации. - - Дает надежную гладкую активацию, хорошо работает на больших масштабах. - - Статья: "GLU Variants Improve Transformer" (Shazeer, 2020). + SwiGLU(x) = (SiLU(W_g x + b_g)) ⊙ (W_u x + b_u) W_d + b_d - Формула: - SwiGLU(x) = SiLU(W_g·x) * (W_u·x) - где SiLU(x) = x*sigma(x) + Типовая реализация (как здесь, по LLAMA/Mistral): + gate = SiLU(Linear_gate(x)) # фитчерный \"gate\" + up = Linear_up(x) # пропускная ветка + mult = gate * up # поэлементное умножение (контроль информации) + out = Linear_down(mult) # финальная проекция + out = Dropout(out) # регуляризация + + Почему это работает: + ------------------- + - Gating позволяет информации проходить \"частично\", динамически подавляя/усиливая сигналы в hidden-space. + - SiLU обеспечивает smooth градиенты (лучше для обучения LLM). + - В экспериментах (PaLM, LLAMA) SwiGLU consistently outperforms ReLU, GELU, обычные GLU. + + Параметры конструктора: + ----------------------- + emb_size: int + Размерность входного (и выходного) признакового пространства. + dropout: float + Dropout после final linear (обычно около 0.1). + + Пример использования: + --------------------- + >>> block = SwiGLU(emb_size=512, dropout=0.1) + >>> x = torch.randn(8, 16, 512) + >>> y = block(x) + >>> print(y.shape) # torch.Size([8, 16, 512]) + + References: + ----------- + - Shazeer, \"GLU Variants Improve Transformer\", 2020: https://arxiv.org/abs/2002.05202 + - PaLM: https://arxiv.org/abs/2204.02311 (Section 4.1) + - LLaMA: https://arxiv.org/abs/2302.13971 + - Mistral: https://arxiv.org/abs/2310.06825 + - HuggingFace discussion: https://huggingface.co/docs/transformers/main/en/model_doc/llama - Args: - emb_size (int): размер входов/выходов - dropout (float): после выходной проекции - Пример: - >>> ff = SwiGLU(emb_size=512, dropout=0.1) - >>> y = ff(torch.randn(2,10,512)) """ def __init__(self, emb_size: int, dropout: float = 0.1): @@ -68,19 +92,24 @@ class SwiGLU(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """ - Прямой проход через SwiGLU слой. + Прямой проход через блок SwiGLU. Args: - x: Входной тензор формы [batch_size, seq_len, emb_size] + ----- + x : torch.Tensor + Входной тензор формы [batch_size, seq_len, emb_size] Returns: - Выходной тензор формы [batch_size, seq_len, emb_size] + -------- + torch.Tensor той же формы Алгоритм: + --------- 1. gate = SiLU(linear_gate(x)) 2. up = linear_up(x) - 3. output = linear_down(gate ⊙ up) - 4. apply dropout + 3. mult = gate * up # поэлементно + 4. out = linear_down(mult) + 5. out = dropout(out) """ # Gate ветвь: линейное преобразование + активация gate_out = self._gate(x) # [batch, seq, 4*emb] diff --git a/llm/tests/core/test_swi_glu.py b/llm/tests/core/test_swi_glu.py new file mode 100644 index 0000000..e6f4d3f --- /dev/null +++ b/llm/tests/core/test_swi_glu.py @@ -0,0 +1,39 @@ +import torch +import pytest +from llm.core.swi_glu import SwiGLU + +def test_swiglu_shape_and_dtype(): + swiglu = SwiGLU(emb_size=32, dropout=0.1) + x = torch.randn(4, 10, 32) + y = swiglu(x) + assert y.shape == x.shape + assert y.dtype == x.dtype + +def test_swiglu_forward_range(): + swiglu = SwiGLU(emb_size=16, dropout=0.0) + x = torch.randn(3, 7, 16) + y = swiglu(x) + assert y.abs().max() < 20 + +def test_swiglu_gradients(): + swiglu = SwiGLU(emb_size=8, dropout=0.0) + x = torch.randn(2, 5, 8, requires_grad=True) + out = swiglu(x) + loss = out.pow(2).sum() + loss.backward() + assert x.grad is not None + assert x.grad.shape == x.shape + +def test_swiglu_fp16(): + swiglu = SwiGLU(emb_size=16, dropout=0.0).half() + x = torch.randn(1, 8, 16).half() + y = swiglu(x) + assert y.shape == x.shape + assert y.dtype == torch.float16 + +def test_swiglu_reproducibility(): + swiglu = SwiGLU(emb_size=8, dropout=0.0) + x = torch.ones(2, 4, 8) + y1 = swiglu(x) + y2 = swiglu(x) + assert torch.allclose(y1, y2)