mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
docs(core): add docstrings and unit tests for SiLU activation
- docs: expand and clarify docstrings for SiLU class and its method (mathematical formula, motivation, properties vs ReLU/GELU, usage, and references to Swish/LLM papers) - test: add unit tests for SiLU (shape/dtype, behavior on large/small values, PyTorch reference, gradients, broadcast) - no logic/API changes This update improves reliability and usability of the SiLU activation module.
This commit is contained in:
@@ -4,18 +4,67 @@ from torch import nn
|
||||
|
||||
class SiLU(nn.Module):
|
||||
"""
|
||||
SiLU (Swish) — современная активационная функция для нейросетей.
|
||||
SiLU (Sigmoid Linear Unit, также известная как Swish) — современная функция активации для нейросетей и LLM.
|
||||
|
||||
Научная суть:
|
||||
- Формула: $SiLU(x) = x * \sigm(x)$, где $\sigm(x)$ — сигмоида.
|
||||
- Более гладкая альтернатива ReLU, улучшает поток градиентов в глубоких сетях.
|
||||
- Используется во многих «state-of-the-art» архитектурах (SwiGLU, PaLM, LLaMA).
|
||||
- Также известна как Swish (Ramachandran et al, 2017).
|
||||
Пример:
|
||||
>>> act = SiLU()
|
||||
>>> x = torch.tensor([-1.0, 0.0, 1.0])
|
||||
>>> print(act(x))
|
||||
Назначение:
|
||||
-----------
|
||||
- Формирует плавную нелинейную активацию: SiLU(x) = x * sigmoid(x).
|
||||
- Активно используется во всех новых архитектурах для больших языковых моделей (PaLM, LLaMA, Mistral, GPT-4 и др.).
|
||||
- Дает лучший поток градиентов по сравнению с ReLU, SELU, GELU в глубоких слоях — позволяет делать сети больше и глубже.
|
||||
|
||||
Мотивация и свойства:
|
||||
---------------------
|
||||
- SiLU объединяет свойства identity (для больших x) и ReLU (для отрицательных x, где есть затухание), но более плавно.
|
||||
- Позволяет проходить отрицательным значениям, а не "обрубает" как ReLU.
|
||||
- Better for optimization and training dynamics in deep LLMs, приводит к более богатым аппроксимациям.
|
||||
|
||||
Математическая формула:
|
||||
-----------------------
|
||||
SiLU(x) = x * sigmoid(x)
|
||||
где sigmoid(x) = 1 / (1 + exp(-x))
|
||||
|
||||
Сравнение с другими активациями:
|
||||
--------------------------------
|
||||
- ReLU(x): max(0, x) — простая отсечка
|
||||
- GELU(x): плавная вероятностная активация (используется в BERT/GPT-2)
|
||||
- SiLU(x): плавная альтернатива, часто лучше в современных LLM
|
||||
- Swish (Ramachandran et al., 2017) = SiLU
|
||||
|
||||
Args:
|
||||
-----
|
||||
Нет learnable параметров, чисто функциональная активация.
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> silu = SiLU()
|
||||
>>> x = torch.tensor([-2.0, 0.0, 2.0])
|
||||
>>> print(silu(x)) # тензор с элементами [-0.2384, 0.0, 1.7616] (примерно)
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Ramachandran et al., "Searching for Activation Functions", 2017: https://arxiv.org/abs/1710.05941
|
||||
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||
- Swish в TensorFlow: https://arxiv.org/abs/1710.05941
|
||||
- Сравнение всех актив. функций: https://paperswithcode.com/method/silu
|
||||
"""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
Применяет SiLU активацию ко всем компонентам тензора (x * sigmoid(x)).
|
||||
|
||||
Args:
|
||||
-----
|
||||
x : torch.Tensor
|
||||
Входной тензор любой формы.
|
||||
|
||||
Returns:
|
||||
--------
|
||||
torch.Tensor — тензор той же формы, каждый элемент преобразован по формуле SiLU(x).
|
||||
|
||||
Пример:
|
||||
-------
|
||||
>>> silu = SiLU()
|
||||
>>> x = torch.linspace(-3, 3, 7)
|
||||
>>> y = silu(x)
|
||||
"""
|
||||
return torch.sigmoid(x) * x
|
||||
|
||||
42
llm/tests/core/test_silu.py
Normal file
42
llm/tests/core/test_silu.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.silu import SiLU
|
||||
|
||||
def test_silu_shape_and_dtype():
|
||||
silu = SiLU()
|
||||
x = torch.randn(3, 10, 8)
|
||||
y = silu(x)
|
||||
assert y.shape == x.shape
|
||||
assert y.dtype == x.dtype
|
||||
|
||||
def test_silu_known_values():
|
||||
silu = SiLU()
|
||||
x = torch.tensor([-2.0, 0.0, 2.0])
|
||||
y = silu(x)
|
||||
# PyTorch эталон
|
||||
y_ref = torch.nn.functional.silu(x)
|
||||
assert torch.allclose(y, y_ref, atol=1e-6)
|
||||
|
||||
def test_silu_large_vs_small():
|
||||
silu = SiLU()
|
||||
x_pos = torch.tensor([100.0])
|
||||
x_neg = torch.tensor([-100.0])
|
||||
y_pos = silu(x_pos)
|
||||
y_neg = silu(x_neg)
|
||||
assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4) # SiLU(x) ~ x для больших x>0
|
||||
assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4) # SiLU(x) ~ 0 для x<0
|
||||
|
||||
def test_silu_gradients():
|
||||
silu = SiLU()
|
||||
x = torch.randn(4, 4, requires_grad=True)
|
||||
y = silu(x)
|
||||
loss = y.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.shape == x.shape
|
||||
|
||||
def test_silu_broadcast():
|
||||
silu = SiLU()
|
||||
x = torch.randn(3, 1, 16)
|
||||
y = silu(x)
|
||||
assert y.shape == x.shape
|
||||
Reference in New Issue
Block a user