diff --git a/llm/src/llm/core/silu.py b/llm/src/llm/core/silu.py index ade010a..9604ff3 100644 --- a/llm/src/llm/core/silu.py +++ b/llm/src/llm/core/silu.py @@ -4,18 +4,67 @@ from torch import nn class SiLU(nn.Module): """ - SiLU (Swish) — современная активационная функция для нейросетей. + SiLU (Sigmoid Linear Unit, также известная как Swish) — современная функция активации для нейросетей и LLM. - Научная суть: - - Формула: $SiLU(x) = x * \sigm(x)$, где $\sigm(x)$ — сигмоида. - - Более гладкая альтернатива ReLU, улучшает поток градиентов в глубоких сетях. - - Используется во многих «state-of-the-art» архитектурах (SwiGLU, PaLM, LLaMA). - - Также известна как Swish (Ramachandran et al, 2017). - Пример: - >>> act = SiLU() - >>> x = torch.tensor([-1.0, 0.0, 1.0]) - >>> print(act(x)) + Назначение: + ----------- + - Формирует плавную нелинейную активацию: SiLU(x) = x * sigmoid(x). + - Активно используется во всех новых архитектурах для больших языковых моделей (PaLM, LLaMA, Mistral, GPT-4 и др.). + - Дает лучший поток градиентов по сравнению с ReLU, SELU, GELU в глубоких слоях — позволяет делать сети больше и глубже. + + Мотивация и свойства: + --------------------- + - SiLU объединяет свойства identity (для больших x) и ReLU (для отрицательных x, где есть затухание), но более плавно. + - Позволяет проходить отрицательным значениям, а не "обрубает" как ReLU. + - Better for optimization and training dynamics in deep LLMs, приводит к более богатым аппроксимациям. + + Математическая формула: + ----------------------- + SiLU(x) = x * sigmoid(x) + где sigmoid(x) = 1 / (1 + exp(-x)) + + Сравнение с другими активациями: + -------------------------------- + - ReLU(x): max(0, x) — простая отсечка + - GELU(x): плавная вероятностная активация (используется в BERT/GPT-2) + - SiLU(x): плавная альтернатива, часто лучше в современных LLM + - Swish (Ramachandran et al., 2017) = SiLU + + Args: + ----- + Нет learnable параметров, чисто функциональная активация. + + Пример использования: + --------------------- + >>> silu = SiLU() + >>> x = torch.tensor([-2.0, 0.0, 2.0]) + >>> print(silu(x)) # тензор с элементами [-0.2384, 0.0, 1.7616] (примерно) + + References: + ----------- + - Ramachandran et al., "Searching for Activation Functions", 2017: https://arxiv.org/abs/1710.05941 + - LLaMA: https://arxiv.org/abs/2302.13971 + - Swish в TensorFlow: https://arxiv.org/abs/1710.05941 + - Сравнение всех актив. функций: https://paperswithcode.com/method/silu """ def forward(self, x: torch.Tensor): + """ + Применяет SiLU активацию ко всем компонентам тензора (x * sigmoid(x)). + + Args: + ----- + x : torch.Tensor + Входной тензор любой формы. + + Returns: + -------- + torch.Tensor — тензор той же формы, каждый элемент преобразован по формуле SiLU(x). + + Пример: + ------- + >>> silu = SiLU() + >>> x = torch.linspace(-3, 3, 7) + >>> y = silu(x) + """ return torch.sigmoid(x) * x diff --git a/llm/tests/core/test_silu.py b/llm/tests/core/test_silu.py new file mode 100644 index 0000000..c452def --- /dev/null +++ b/llm/tests/core/test_silu.py @@ -0,0 +1,42 @@ +import torch +import pytest +from llm.core.silu import SiLU + +def test_silu_shape_and_dtype(): + silu = SiLU() + x = torch.randn(3, 10, 8) + y = silu(x) + assert y.shape == x.shape + assert y.dtype == x.dtype + +def test_silu_known_values(): + silu = SiLU() + x = torch.tensor([-2.0, 0.0, 2.0]) + y = silu(x) + # PyTorch эталон + y_ref = torch.nn.functional.silu(x) + assert torch.allclose(y, y_ref, atol=1e-6) + +def test_silu_large_vs_small(): + silu = SiLU() + x_pos = torch.tensor([100.0]) + x_neg = torch.tensor([-100.0]) + y_pos = silu(x_pos) + y_neg = silu(x_neg) + assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4) # SiLU(x) ~ x для больших x>0 + assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4) # SiLU(x) ~ 0 для x<0 + +def test_silu_gradients(): + silu = SiLU() + x = torch.randn(4, 4, requires_grad=True) + y = silu(x) + loss = y.sum() + loss.backward() + assert x.grad is not None + assert x.grad.shape == x.shape + +def test_silu_broadcast(): + silu = SiLU() + x = torch.randn(3, 1, 16) + y = silu(x) + assert y.shape == x.shape