llm-arch-research/llm/tests/models/test_gpt2.py
Sergey Penkovsky dc440a3938 test(gpt2): add unit tests for generation, cache behavior, and error conditions
- Covers forward pass with and without KV-cache
- Verifies correct sequence generation for greedy, top-k, and top-p sampling
- Adds ValueError test for exceeding max sequence length
- Uses a small toy config and minimal setup for fast test feedback

Motivation: Prevent regressions in decoding, sampling, and KV-cache logic in the GPT2 implementation.
2025-10-15 14:36:32 +03:00


import torch
import pytest

from llm.models.gpt.gpt2 import GPT2


@pytest.fixture
def config():
    return {
        "vocab_size": 100,
        "embed_dim": 32,
        "num_heads": 4,
        "num_layers": 2,
        "max_position_embeddings": 16,
        "dropout": 0.0,
    }


@pytest.fixture
def model(config):
    return GPT2(config)


def test_forward_basic(model):
    x = torch.randint(0, 100, (2, 8))
    logits, cache = model(x)
    assert logits.shape == (2, 8, 100)
    assert isinstance(cache, list)
    # One cache entry per decoder block.
    assert len(cache) == len(model._decoders)


def test_forward_with_cache(model):
    x = torch.randint(0, 100, (2, 4))
    logits, cache = model(x, use_cache=True)
    x2 = torch.randint(0, 100, (2, 1))
    logits2, cache2 = model(x2, use_cache=True, cache=cache)
    assert logits2.shape == (2, 1, 100)
    assert isinstance(cache2, list)


def test_generate_and_shape(model):
    x = torch.randint(0, 100, (1, 5))
    result = model.generate(x, max_new_tokens=3, do_sample=False)
    # 5 prompt tokens + 3 generated tokens = 8.
    assert result.shape == (1, 8)
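

# Sketch of an extra greedy-decoding check (not in the original commit):
# with do_sample=False, generation is greedy per the commit message, so
# repeated calls on the same prompt should be deterministic. Assumes
# do_sample=False is the only source of randomness in generate().
def test_generate_greedy_is_deterministic(model):
    x = torch.randint(0, 100, (1, 5))
    out1 = model.generate(x, max_new_tokens=3, do_sample=False)
    out2 = model.generate(x, max_new_tokens=3, do_sample=False)
    assert torch.equal(out1, out2)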


def test_forward_sequence_too_long(model, config):
    x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
    with pytest.raises(ValueError):
        model(x)


def test_generate_with_sampling_topk(model):
    x = torch.randint(0, 100, (1, 3))
    out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
    assert out.shape == (1, 5)


def test_generate_with_sampling_topp(model):
    x = torch.randint(0, 100, (1, 3))
    out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
    assert out.shape == (1, 5)
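

# Hedged sketch of a KV-cache consistency check (not in the original commit):
# incremental decoding with the cache should produce the same last-token
# logits as a full forward pass over the concatenated sequence. Assumes
# model(...) returns (logits, cache) as exercised above and that
# dropout=0.0 makes the forward pass deterministic.
def test_cache_matches_full_forward(model):
    x = torch.randint(0, 100, (1, 4))
    x_next = torch.randint(0, 100, (1, 1))

    # Full forward over the whole 5-token sequence.
    full_logits, _ = model(torch.cat([x, x_next], dim=1))

    # Incremental: prefill 4 tokens, then feed the 5th token with the cache.
    _, cache = model(x, use_cache=True)
    step_logits, _ = model(x_next, use_cache=True, cache=cache)

    assert torch.allclose(full_logits[:, -1, :], step_logits[:, -1, :], atol=1e-5)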