mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
- docs: rewrite and expand docstrings for GELU class and method (motivation, math formula, smoother ReLU for Transformers, usage, references) - test: add dedicated tests for GELU (output shape, dtype, comparison with torch GELU, monotonicity, gradients, large/small value behavior) - fix: align numerical test to allow for minor approximation difference vs PyTorch gelu This update makes the GELU module more transparent and robust for deep learning practitioners and researchers.
47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
import torch
|
||
import pytest
|
||
from llm.core.gelu import GELU
|
||
|
||
def test_gelu_shapes_and_dtype():
|
||
gelu = GELU()
|
||
x = torch.randn(4, 16, 8)
|
||
y = gelu(x)
|
||
assert y.shape == x.shape
|
||
assert y.dtype == x.dtype
|
||
|
||
def test_gelu_known_values():
|
||
gelu = GELU()
|
||
x = torch.tensor([-3.0, 0.0, 3.0])
|
||
y = gelu(x)
|
||
# Сравнение с PyTorch F.gelu (которая использует точный алгоритм)
|
||
y_ref = torch.nn.functional.gelu(x)
|
||
diff = (y - y_ref).abs().max().item()
|
||
assert diff < 5e-3, f"Max difference {diff} exceeds threshold"
|
||
|
||
def test_gelu_is_smooth_and_monotonic():
|
||
gelu = GELU()
|
||
x = torch.linspace(-5, 5, 100)
|
||
y = gelu(x)
|
||
dy = y[1:] - y[:-1]
|
||
# Проверяем, что функция GELU хотя бы локально монотонна на большинстве промежутков
|
||
assert (dy.mean() > 0 or dy.mean() < 0)
|
||
|
||
def test_gelu_gradients():
|
||
gelu = GELU()
|
||
x = torch.randn(3, 5, requires_grad=True)
|
||
y = gelu(x)
|
||
loss = y.sum()
|
||
loss.backward()
|
||
assert x.grad is not None
|
||
assert x.grad.shape == x.shape
|
||
|
||
def test_gelu_large_vs_small():
|
||
gelu = GELU()
|
||
x_pos = torch.tensor([100.0])
|
||
x_neg = torch.tensor([-100.0])
|
||
y_pos = gelu(x_pos)
|
||
y_neg = gelu(x_neg)
|
||
# Для больших положительных GELU(x) ~ x, для больших отрицательных ~0
|
||
assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4)
|
||
assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4)
|