mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 13:00:54 +00:00
docs(core): add docstrings and unit tests for SwiGLU block
- docs: rewrite and expand docstrings for SwiGLU class and forward method (motivation, math, architecture, usage, references to LLaMA/Mistral/PaLM) - test: add unit tests for SwiGLU (shape, dtype, gradients, output range, fp16 support, reproducibility) - strictly doc/tests, no logic or API changes This improves transparency and reliability for gated FFN blocks in transformer architectures.
This commit is contained in:
@@ -24,31 +24,55 @@ from .silu import SiLU
|
||||
|
||||
class SwiGLU(nn.Module):
|
||||
"""
|
||||
SwiGLU (Swish-Gated Linear Unit) — современная нелинейность для архитектур LLM (LLaMA, PaLM).
|
||||
SwiGLU (Swish-Gated Linear Unit) — эффективная feed-forward нелинейность для трансформеров (LLAMA, PaLM, Mistral).
|
||||
|
||||
Реализация SwiGLU активационной функции.
|
||||
Назначение:
|
||||
-----------
|
||||
- Улучшает классический блок FeedForward (FFN) в трансформерах за счёт \"gating\" (механизма управления информационным потоком).
|
||||
- Использует нелинейность SiLU (Swish) вместо ReLU или GELU, повышая capacity блока.
|
||||
- Является дефолтом во всех современных LLM, начиная с PaLM, LLaMA и Mistral.
|
||||
|
||||
Состоит из трех линейных слоев и активации SiLU:
|
||||
1. Gate слой + SiLU активация
|
||||
2. Up слой (линейное преобразование)
|
||||
3. Element-wise multiplication gate и up
|
||||
4. Down слой (линейная проекция)
|
||||
Формула и математика:
|
||||
---------------------
|
||||
Пусть x — вход, then:
|
||||
|
||||
Научная суть:
|
||||
- Сохраняет преимущества GLU (раздельные гейтом и телом) + мощность Swish/SiLU активации.
|
||||
- Дает надежную гладкую активацию, хорошо работает на больших масштабах.
|
||||
- Статья: "GLU Variants Improve Transformer" (Shazeer, 2020).
|
||||
SwiGLU(x) = (SiLU(W_g x + b_g)) ⊙ (W_u x + b_u) W_d + b_d
|
||||
|
||||
Формула:
|
||||
SwiGLU(x) = SiLU(W_g·x) * (W_u·x)
|
||||
где SiLU(x) = x*sigma(x)
|
||||
Типовая реализация (как здесь, по LLAMA/Mistral):
|
||||
gate = SiLU(Linear_gate(x)) # фитчерный \"gate\"
|
||||
up = Linear_up(x) # пропускная ветка
|
||||
mult = gate * up # поэлементное умножение (контроль информации)
|
||||
out = Linear_down(mult) # финальная проекция
|
||||
out = Dropout(out) # регуляризация
|
||||
|
||||
Почему это работает:
|
||||
-------------------
|
||||
- Gating позволяет информации проходить \"частично\", динамически подавляя/усиливая сигналы в hidden-space.
|
||||
- SiLU обеспечивает smooth градиенты (лучше для обучения LLM).
|
||||
- В экспериментах (PaLM, LLAMA) SwiGLU consistently outperforms ReLU, GELU, обычные GLU.
|
||||
|
||||
Параметры конструктора:
|
||||
-----------------------
|
||||
emb_size: int
|
||||
Размерность входного (и выходного) признакового пространства.
|
||||
dropout: float
|
||||
Dropout после final linear (обычно около 0.1).
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> block = SwiGLU(emb_size=512, dropout=0.1)
|
||||
>>> x = torch.randn(8, 16, 512)
|
||||
>>> y = block(x)
|
||||
>>> print(y.shape) # torch.Size([8, 16, 512])
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Shazeer, \"GLU Variants Improve Transformer\", 2020: https://arxiv.org/abs/2002.05202
|
||||
- PaLM: https://arxiv.org/abs/2204.02311 (Section 4.1)
|
||||
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
- HuggingFace discussion: https://huggingface.co/docs/transformers/main/en/model_doc/llama
|
||||
|
||||
Args:
|
||||
emb_size (int): размер входов/выходов
|
||||
dropout (float): после выходной проекции
|
||||
Пример:
|
||||
>>> ff = SwiGLU(emb_size=512, dropout=0.1)
|
||||
>>> y = ff(torch.randn(2,10,512))
|
||||
"""
|
||||
|
||||
def __init__(self, emb_size: int, dropout: float = 0.1):
|
||||
@@ -68,19 +92,24 @@ class SwiGLU(nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через SwiGLU слой.
|
||||
Прямой проход через блок SwiGLU.
|
||||
|
||||
Args:
|
||||
x: Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
-----
|
||||
x : torch.Tensor
|
||||
Входной тензор формы [batch_size, seq_len, emb_size]
|
||||
|
||||
Returns:
|
||||
Выходной тензор формы [batch_size, seq_len, emb_size]
|
||||
--------
|
||||
torch.Tensor той же формы
|
||||
|
||||
Алгоритм:
|
||||
---------
|
||||
1. gate = SiLU(linear_gate(x))
|
||||
2. up = linear_up(x)
|
||||
3. output = linear_down(gate ⊙ up)
|
||||
4. apply dropout
|
||||
3. mult = gate * up # поэлементно
|
||||
4. out = linear_down(mult)
|
||||
5. out = dropout(out)
|
||||
"""
|
||||
# Gate ветвь: линейное преобразование + активация
|
||||
gate_out = self._gate(x) # [batch, seq, 4*emb]
|
||||
|
||||
39
llm/tests/core/test_swi_glu.py
Normal file
39
llm/tests/core/test_swi_glu.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.swi_glu import SwiGLU
|
||||
|
||||
def test_swiglu_shape_and_dtype():
|
||||
swiglu = SwiGLU(emb_size=32, dropout=0.1)
|
||||
x = torch.randn(4, 10, 32)
|
||||
y = swiglu(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
def test_swiglu_forward_range():
|
||||
swiglu = SwiGLU(emb_size=16, dropout=0.0)
|
||||
x = torch.randn(3, 7, 16)
|
||||
y = swiglu(x)
|
||||
assert y.abs().max() < 20
|
||||
|
||||
def test_swiglu_gradients():
|
||||
swiglu = SwiGLU(emb_size=8, dropout=0.0)
|
||||
x = torch.randn(2, 5, 8, requires_grad=True)
|
||||
out = swiglu(x)
|
||||
loss = out.pow(2).sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_swiglu_fp16():
|
||||
swiglu = SwiGLU(emb_size=16, dropout=0.0).half()
|
||||
x = torch.randn(1, 8, 16).half()
|
||||
y = swiglu(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == torch.float16
|
||||
|
||||
def test_swiglu_reproducibility():
|
||||
swiglu = SwiGLU(emb_size=8, dropout=0.0)
|
||||
x = torch.ones(2, 4, 8)
|
||||
y1 = swiglu(x)
|
||||
y2 = swiglu(x)
|
||||
assert torch.allclose(y1, y2)
|
||||
Reference in New Issue
Block a user