# Files
# llm-arch-research/llm/tests/core/test_geglu.py
#
# 61 lines
# 1.6 KiB
# Python
# Raw Normal View History

import torch
import pytest
from llm.core.geglu import GeGLU
@pytest.fixture
def geglu():
    """Build the shared ``GeGLU`` layer used by the fixture-based tests.

    The layer is constructed with ``emb_size=16`` and ``dropout=0.1``.
    """
    layer = GeGLU(emb_size=16, dropout=0.1)
    return layer
def test_forward_shape(geglu):
    """The forward pass must return a tensor shaped exactly like its input."""
    inputs = torch.randn(2, 5, 16)
    outputs = geglu(inputs)
    assert outputs.shape == inputs.shape
def test_forward_no_batch(geglu):
    """A (1, 16) sample given an extra leading axis yields a (1, 1, 16) output."""
    sample = torch.randn(1, 16)
    # Add a leading dimension before calling the layer: (1, 16) -> (1, 1, 16).
    expanded = sample.unsqueeze(0)
    result = geglu(expanded)
    assert result.shape == (1, 1, 16)
@pytest.mark.skip(reason="float16 not supported without parameter casting")
def test_forward_dtype_fp16():
    """Half-precision input should produce a half-precision output of equal shape.

    Currently skipped; the reason is recorded on the ``skip`` marker above.
    """
    layer = GeGLU(emb_size=8, dropout=0.0)
    half_input = torch.randn(2, 4, 8).to(torch.float16)
    half_output = layer(half_input)
    assert half_output.shape == half_input.shape
    assert half_output.dtype == torch.float16
def test_forward_no_dropout():
    """With dropout set to 0.0 the output must contain no NaN and no Inf values."""
    layer = GeGLU(emb_size=4, dropout=0.0)
    out = layer(torch.randn(3, 2, 4))
    assert not out.isnan().any()
    assert not out.isinf().any()
def test_gradient_flow(geglu):
    """Backpropagating a scalar reduction of the output must populate the input gradient."""
    inputs = torch.randn(3, 8, 16, requires_grad=True)
    # Reduce to a scalar so ``backward()`` can be called without arguments.
    total = geglu(inputs).sum()
    total.backward()
    grad = inputs.grad
    assert grad is not None
    assert grad.shape == inputs.shape
def test_forward_repeatability():
    """Two runs started from the same RNG seed must give matching outputs."""

    def seeded_forward():
        # Seed first, then build the layer, then draw the input, so both
        # runs consume the random stream in the same order.
        torch.manual_seed(42)
        layer = GeGLU(emb_size=8, dropout=0.0)
        sample = torch.randn(3, 2, 8)
        return layer(sample)

    first = seeded_forward()
    second = seeded_forward()
    assert torch.allclose(first, second, atol=1e-5)
def test_edge_small_large():
    """Output shape must match input shape at a tiny and a large embedding size."""
    # (emb_size, input shape) pairs, checked in order: smallest first, then largest.
    cases = (
        (2, (2, 2, 2)),
        (256, (1, 1, 256)),
    )
    for emb_size, shape in cases:
        layer = GeGLU(emb_size=emb_size, dropout=0.0)
        sample = torch.randn(*shape)
        result = layer(sample)
        assert result.shape == sample.shape