Files
llm-arch-research/llm/tests/core/test_gelu.py
Sergey Penkovsky 0832d78acf docs(core): improve docstrings and add unit tests for GELU activation
- docs: rewrite and expand docstrings for the GELU class and its method (motivation, math formula, GELU as a smooth alternative to ReLU in Transformers, usage, references)
- test: add dedicated tests for GELU (output shape, dtype, comparison with torch GELU, monotonicity, gradients, large/small-value behavior)
- fix: loosen the numerical test tolerance to account for the minor approximation difference vs PyTorch's gelu

This update makes the GELU module more transparent and robust for deep learning practitioners and researchers.
2025-10-16 13:59:38 +03:00
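For context: the 5e-3 tolerance in test_gelu_known_values below, together with the commit note about an "approximation difference vs PyTorch gelu", suggests that llm.core.gelu.GELU implements the tanh approximation of GELU rather than the exact erf-based form. The following is a minimal sketch of such a module under that assumption; the actual implementation is not shown on this page and may differ:

import math

import torch
import torch.nn as nn


class GELU(nn.Module):
    """Tanh approximation of GELU:
    0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return 0.5 * x * (1.0 + torch.tanh(
            math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))
        ))

The tanh form differs from PyTorch's default exact gelu by well under 5e-3 across typical input ranges, which is consistent with the tolerance used in the tests.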


import torch
import pytest

from llm.core.gelu import GELU


def test_gelu_shapes_and_dtype():
    gelu = GELU()
    x = torch.randn(4, 16, 8)
    y = gelu(x)
    assert y.shape == x.shape
    assert y.dtype == x.dtype


def test_gelu_known_values():
    gelu = GELU()
    x = torch.tensor([-3.0, 0.0, 3.0])
    y = gelu(x)
    # Compare against PyTorch's F.gelu, which uses the exact erf-based formula;
    # the loose tolerance absorbs the small error of the tanh approximation.
    y_ref = torch.nn.functional.gelu(x)
    diff = (y - y_ref).abs().max().item()
    assert diff < 5e-3, f"Max difference {diff} exceeds threshold"


def test_gelu_is_smooth_and_monotonic():
    gelu = GELU()
    x = torch.linspace(-5, 5, 100)
    y = gelu(x)
    dy = y[1:] - y[:-1]
    # GELU is not globally monotonic (it dips slightly for moderately negative
    # inputs, with a minimum near x ~ -0.75), so check that the function is
    # increasing on the majority of the sampled intervals.
    assert (dy > 0).float().mean().item() > 0.5


def test_gelu_gradients():
    gelu = GELU()
    x = torch.randn(3, 5, requires_grad=True)
    y = gelu(x)
    loss = y.sum()
    loss.backward()
    assert x.grad is not None
    assert x.grad.shape == x.shape


def test_gelu_large_vs_small():
    gelu = GELU()
    x_pos = torch.tensor([100.0])
    x_neg = torch.tensor([-100.0])
    y_pos = gelu(x_pos)
    y_neg = gelu(x_neg)
    # In the tails, GELU(x) ~ x for large positive x and GELU(x) ~ 0 for large
    # negative x.
    assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4)
    assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4)
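
Assuming pytest is installed and the llm package is importable (e.g. via an editable install), the suite above can be run from the repository root with a command along the lines of pytest llm/tests/core/test_gelu.py -v.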