refactor(gpt1): migrate Decoder to GptDecoder, unify API, and update tests

- Renamed Decoder (and decoder.py) to GptDecoder (gpt_decoder.py) for clarity in GPT1
- Added cache and use_cache parameters to GptDecoder.forward, unifying the API (see the sketch below)
- Adapted all usages in the GPT model to the new decoder structure and its tuple output
- Updated core tests (test_gpt.py, test_gpt_decoder.py, test_basic.py) to expect the tuple output (or logits) and to keep shape/device checks working as before
- Improves clarity and future extensibility for autoregressive generation and benchmarking
- No changes to the architecture or training loop; purely an API and test modernization
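
For reference, a minimal sketch of the unified call convention. The constructor arguments and attention_mask keyword mirror the tests below; the exact cache format, defaults, and the illustrative dimensions are assumptions, not confirmed by this diff:

import torch
from llm.core.gpt_decoder import GptDecoder

# Hypothetical dimensions, for illustration only.
decoder = GptDecoder(num_heads=4, emb_size=64, head_size=16, max_seq_len=1024)
x = torch.randn(2, 16, 64)  # (batch, seq_len, emb_size)
mask = torch.tril(torch.ones(16, 16))

# forward() now returns a tuple; the second element carries the cache
# (assumed to be None/empty when use_cache is not requested).
hidden, _ = decoder(x, attention_mask=mask)

# Prefill once with use_cache=True, then feed only the newest token back in,
# passing the returned cache so past keys/values are not recomputed.
hidden, cache = decoder(x, use_cache=True)
next_step = torch.randn(2, 1, 64)
hidden, cache = decoder(next_step, cache=cache, use_cache=True)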
Sergey Penkovsky
2025-10-22 16:27:08 +03:00
parent ddc4924a37
commit 25caf69ced
5 changed files with 113 additions and 60 deletions

@@ -4,17 +4,17 @@ Tests for decoder block.
import pytest
import torch
from llm.core.decoder import Decoder
from llm.core.gpt_decoder import GptDecoder
class TestDecoder:
class TestGptDecoder:
"""Test cases for Decoder."""
def test_initialization(self, embed_dim, num_heads):
"""Test that Decoder can be initialized."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -32,7 +32,7 @@ class TestDecoder:
"""Test forward pass of Decoder."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -40,7 +40,7 @@ class TestDecoder:
)
# Forward pass
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Check output shape
assert output.shape == random_embeddings.shape
@@ -50,7 +50,7 @@ class TestDecoder:
"""Test forward pass with causal mask."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -62,7 +62,7 @@ class TestDecoder:
mask = torch.tril(torch.ones(seq_len, seq_len))
# Forward pass with causal mask
output = decoder(random_embeddings, mask=mask)
output, _ = decoder(random_embeddings, attention_mask=mask)
# Check output shape
assert output.shape == random_embeddings.shape
@@ -71,14 +71,14 @@ class TestDecoder:
"""Test that residual connections are properly applied."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# With residual connections and layer norm, the output shouldn't be
# too different from input (in terms of scale/distribution)
@@ -92,14 +92,14 @@ class TestDecoder:
"""Test that layer normalization is applied."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Check that output has reasonable statistics (due to layer norm)
# Mean should be close to 0, std close to 1 for each sequence position
@@ -114,7 +114,7 @@ class TestDecoder:
"""Test that gradients flow through Decoder."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -122,7 +122,7 @@ class TestDecoder:
)
# Forward pass
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Create a dummy loss and backward pass
loss = output.sum()
@@ -139,7 +139,7 @@ class TestDecoder:
"""Test that Decoder works on correct device."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -148,7 +148,7 @@ class TestDecoder:
inputs = random_embeddings.to(device)
# Forward pass
output = decoder(inputs)
output, _ = decoder(inputs)
# Check device consistency
assert output.device == device
@@ -165,7 +165,7 @@ class TestDecoder:
for embed_dim, num_heads in test_cases:
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -174,7 +174,7 @@ class TestDecoder:
batch_size, seq_len = 2, 16
inputs = torch.randn(batch_size, seq_len, embed_dim)
output = decoder(inputs)
output, _ = decoder(inputs)
assert output.shape == inputs.shape
@@ -183,7 +183,7 @@ class TestDecoder:
"""Test Decoder with different input shapes."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -191,7 +191,7 @@ class TestDecoder:
)
inputs = torch.randn(batch_size, seq_len, embed_dim)
output = decoder(inputs)
output, _ = decoder(inputs)
assert output.shape == (batch_size, seq_len, embed_dim)
@@ -199,7 +199,7 @@ class TestDecoder:
"""Test that Decoder behaves differently in train vs eval mode."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -209,11 +209,11 @@ class TestDecoder:
# Training mode
decoder.train()
output_train = decoder(random_embeddings)
output_train, _ = decoder(random_embeddings)
# Evaluation mode
decoder.eval()
output_eval = decoder(random_embeddings)
output_eval, _ = decoder(random_embeddings)
# Outputs should be different due to dropout
assert not torch.allclose(output_train, output_eval)
@@ -222,7 +222,7 @@ class TestDecoder:
"""Test that parameters are properly initialized."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,