mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-24 13:32:08 +00:00
61 lines
1.6 KiB
Python
61 lines
1.6 KiB
Python
|
|
import torch
|
||
|
|
import pytest
|
||
|
|
from llm.core.geglu import GeGLU
|
||
|
|
|
||
|
|
@pytest.fixture
def geglu():
    """Provide a GeGLU instance built with emb_size=16 and dropout=0.1."""
    layer = GeGLU(emb_size=16, dropout=0.1)
    return layer
|
||
|
|
|
||
|
|
def test_forward_shape(geglu):
    """The forward pass must return a tensor shaped like its input."""
    inputs = torch.randn(2, 5, 16)
    assert geglu(inputs).shape == inputs.shape
|
||
|
|
|
||
|
|
def test_forward_no_batch(geglu):
    """A (1, 16) tensor given a leading axis must produce a (1, 1, 16) output."""
    flat = torch.randn(1, 16)
    # Prepend a size-1 axis before handing the tensor to the layer.
    expanded = flat.unsqueeze(0)
    result = geglu(expanded)
    assert result.shape == (1, 1, 16)
|
||
|
|
|
||
|
|
@pytest.mark.skip(reason="float16 not supported without parameter casting")
def test_forward_dtype_fp16():
    """Half-precision input should keep its shape and dtype (currently skipped)."""
    layer = GeGLU(emb_size=8, dropout=0.0)
    half_input = torch.randn(2, 4, 8).half()
    out = layer(half_input)
    assert out.shape == half_input.shape
    assert out.dtype == torch.float16
|
||
|
|
|
||
|
|
def test_forward_no_dropout():
    """With dropout=0.0 the output must contain neither NaN nor infinite values."""
    layer = GeGLU(emb_size=4, dropout=0.0)
    out = layer(torch.randn(3, 2, 4))
    # Check NaN first, then infinity, one assertion per predicate.
    for is_bad in (torch.isnan, torch.isinf):
        assert not is_bad(out).any()
|
||
|
|
|
||
|
|
def test_gradient_flow(geglu):
    """Backpropagating the summed output must populate a gradient on the input."""
    inputs = torch.randn(3, 8, 16, requires_grad=True)
    loss = geglu(inputs).sum()
    loss.backward()
    grad = inputs.grad
    assert grad is not None
    assert grad.shape == inputs.shape
|
||
|
|
|
||
|
|
def test_forward_repeatability():
    """Re-seeding before building the layer and input must reproduce the output."""

    def seeded_forward():
        # Keep this order: seed, construct the layer, sample the input, forward.
        torch.manual_seed(42)
        layer = GeGLU(emb_size=8, dropout=0.0)
        return layer(torch.randn(3, 2, 8))

    first = seeded_forward()
    second = seeded_forward()
    assert torch.allclose(first, second, atol=1e-5)
|
||
|
|
|
||
|
|
def test_edge_small_large():
    """Output shape must match input shape for emb_size=2 and emb_size=256."""
    # (emb_size, input shape) pairs, exercised smallest first.
    cases = (
        (2, (2, 2, 2)),
        (256, (1, 1, 256)),
    )
    for emb_size, shape in cases:
        layer = GeGLU(emb_size=emb_size, dropout=0.0)
        sample = torch.randn(*shape)
        assert layer(sample).shape == sample.shape
|