docs(core): add docstrings and unit tests for SwiGLU block

- docs: rewrite and expand docstrings for SwiGLU class and forward method (motivation, math, architecture, usage, references to LLaMA/Mistral/PaLM)
- test: add unit tests for SwiGLU (shape, dtype, gradients, output range, fp16 support, reproducibility)
- strictly doc/tests, no logic or API changes

This improves transparency and reliability for gated FFN blocks in transformer architectures.
This commit is contained in:
Sergey Penkovsky
2025-10-16 15:09:09 +03:00
parent 64d33783e0
commit 516f9580fb
2 changed files with 93 additions and 25 deletions

View File

@@ -24,31 +24,55 @@ from .silu import SiLU
class SwiGLU(nn.Module): class SwiGLU(nn.Module):
""" """
SwiGLU (Swish-Gated Linear Unit) — современная нелинейность для архитектур LLM (LLaMA, PaLM). SwiGLU (Swish-Gated Linear Unit) — эффективная feed-forward нелинейность для трансформеров (LLAMA, PaLM, Mistral).
Реализация SwiGLU активационной функции. Назначение:
-----------
- Улучшает классический блок FeedForward (FFN) в трансформерах за счёт \"gating\" (механизма управления информационным потоком).
- Использует нелинейность SiLU (Swish) вместо ReLU или GELU, повышая capacity блока.
- Является дефолтом во всех современных LLM, начиная с PaLM, LLaMA и Mistral.
Состоит из трех линейных слоев и активации SiLU: Формула и математика:
1. Gate слой + SiLU активация ---------------------
2. Up слой (линейное преобразование) Пусть x — вход, then:
3. Element-wise multiplication gate и up
4. Down слой (линейная проекция)
Научная суть: SwiGLU(x) = (SiLU(W_g x + b_g)) ⊙ (W_u x + b_u) W_d + b_d
- Сохраняет преимущества GLU (раздельные гейтом и телом) + мощность Swish/SiLU активации.
- Дает надежную гладкую активацию, хорошо работает на больших масштабах.
- Статья: "GLU Variants Improve Transformer" (Shazeer, 2020).
Формула: Типовая реализация (как здесь, по LLAMA/Mistral):
SwiGLU(x) = SiLU(W_g·x) * (W_u·x) gate = SiLU(Linear_gate(x)) # фитчерный \"gate\"
где SiLU(x) = x*sigma(x) up = Linear_up(x) # пропускная ветка
mult = gate * up # поэлементное умножение (контроль информации)
out = Linear_down(mult) # финальная проекция
out = Dropout(out) # регуляризация
Почему это работает:
-------------------
- Gating позволяет информации проходить \"частично\", динамически подавляя/усиливая сигналы в hidden-space.
- SiLU обеспечивает smooth градиенты (лучше для обучения LLM).
- В экспериментах (PaLM, LLAMA) SwiGLU consistently outperforms ReLU, GELU, обычные GLU.
Параметры конструктора:
-----------------------
emb_size: int
Размерность входного (и выходного) признакового пространства.
dropout: float
Dropout после final linear (обычно около 0.1).
Пример использования:
---------------------
>>> block = SwiGLU(emb_size=512, dropout=0.1)
>>> x = torch.randn(8, 16, 512)
>>> y = block(x)
>>> print(y.shape) # torch.Size([8, 16, 512])
References:
-----------
- Shazeer, \"GLU Variants Improve Transformer\", 2020: https://arxiv.org/abs/2002.05202
- PaLM: https://arxiv.org/abs/2204.02311 (Section 4.1)
- LLaMA: https://arxiv.org/abs/2302.13971
- Mistral: https://arxiv.org/abs/2310.06825
- HuggingFace discussion: https://huggingface.co/docs/transformers/main/en/model_doc/llama
Args:
emb_size (int): размер входов/выходов
dropout (float): после выходной проекции
Пример:
>>> ff = SwiGLU(emb_size=512, dropout=0.1)
>>> y = ff(torch.randn(2,10,512))
""" """
def __init__(self, emb_size: int, dropout: float = 0.1): def __init__(self, emb_size: int, dropout: float = 0.1):
@@ -68,19 +92,24 @@ class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
""" """
Прямой проход через SwiGLU слой. Прямой проход через блок SwiGLU.
Args: Args:
x: Входной тензор формы [batch_size, seq_len, emb_size] -----
x : torch.Tensor
Входной тензор формы [batch_size, seq_len, emb_size]
Returns: Returns:
Выходной тензор формы [batch_size, seq_len, emb_size] --------
torch.Tensor той же формы
Алгоритм: Алгоритм:
---------
1. gate = SiLU(linear_gate(x)) 1. gate = SiLU(linear_gate(x))
2. up = linear_up(x) 2. up = linear_up(x)
3. output = linear_down(gate up) 3. mult = gate * up # поэлементно
4. apply dropout 4. out = linear_down(mult)
5. out = dropout(out)
""" """
# Gate ветвь: линейное преобразование + активация # Gate ветвь: линейное преобразование + активация
gate_out = self._gate(x) # [batch, seq, 4*emb] gate_out = self._gate(x) # [batch, seq, 4*emb]

View File

@@ -0,0 +1,39 @@
import torch
import pytest
from llm.core.swi_glu import SwiGLU
def test_swiglu_shape_and_dtype():
swiglu = SwiGLU(emb_size=32, dropout=0.1)
x = torch.randn(4, 10, 32)
y = swiglu(x)
assert y.shape == x.shape
assert y.dtype == x.dtype
def test_swiglu_forward_range():
swiglu = SwiGLU(emb_size=16, dropout=0.0)
x = torch.randn(3, 7, 16)
y = swiglu(x)
assert y.abs().max() < 20
def test_swiglu_gradients():
swiglu = SwiGLU(emb_size=8, dropout=0.0)
x = torch.randn(2, 5, 8, requires_grad=True)
out = swiglu(x)
loss = out.pow(2).sum()
loss.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_swiglu_fp16():
swiglu = SwiGLU(emb_size=16, dropout=0.0).half()
x = torch.randn(1, 8, 16).half()
y = swiglu(x)
assert y.shape == x.shape
assert y.dtype == torch.float16
def test_swiglu_reproducibility():
swiglu = SwiGLU(emb_size=8, dropout=0.0)
x = torch.ones(2, 4, 8)
y1 = swiglu(x)
y2 = swiglu(x)
assert torch.allclose(y1, y2)