Mirror of https://github.com/pese-git/llm-arch-research.git (synced 2026-01-23 21:10:54 +00:00)
refactor(gpt1): migrate Decoder to GptDecoder, unify API, and update tests
- Renamed Decoder (and decoder.py) to GptDecoder (gpt_decoder.py) for clarity in GPT1
- Implemented support for the cache and use_cache parameters in GptDecoder.forward (API unification; see the sketch below)
- Adapted all usages in the GPT model to the new decoder structure and its tuple output
- Refactored the core tests (test_gpt.py, test_gpt_decoder.py, test_basic.py) to correctly expect a tuple (or bare logits) and to keep the shape/device checks working as before
- Improved clarity and future extensibility for autoregressive generation and benchmarking
- No changes to architectural details or the training loop; purely API and test modernization
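The diff below only touches call sites, so the new decoder signature itself never appears on this page. As a rough illustration of the unified API described above, here is a minimal sketch of what a cache-aware forward returning an (output, cache) tuple could look like. The parameter names cache/use_cache, the attention_mask keyword, and the tuple return follow the commit message and the test changes; everything else (the attention/feed-forward internals, the cache layout) is an assumption for illustration, not the repository's actual GptDecoder.

```python
import torch
from torch import nn


class GptDecoderSketch(nn.Module):
    """Illustrative stand-in for llm.core.gpt_decoder.GptDecoder (not the real code)."""

    def __init__(self, num_heads, emb_size, head_size, max_seq_len):
        super().__init__()
        # head_size and max_seq_len are accepted for signature parity; this sketch
        # leans on nn.MultiheadAttention and does not use them directly.
        self.attn = nn.MultiheadAttention(emb_size, num_heads, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(emb_size, 4 * emb_size), nn.GELU(), nn.Linear(4 * emb_size, emb_size)
        )
        self.ln1 = nn.LayerNorm(emb_size)
        self.ln2 = nn.LayerNorm(emb_size)

    def forward(self, x, attention_mask=None, cache=None, use_cache=False):
        # When a cache of previous hidden states is passed, attend over old + new tokens.
        kv = x if cache is None else torch.cat([cache, x], dim=1)
        attn_mask = None
        if attention_mask is not None and cache is None:
            # The tests pass a lower-triangular 0/1 mask; convert zeros to "disallowed".
            attn_mask = attention_mask == 0
        attn_out, _ = self.attn(x, kv, kv, attn_mask=attn_mask)
        x = self.ln1(x + attn_out)
        x = self.ln2(x + self.ff(x))
        # Always return a tuple, matching the `output, _ = decoder(...)` pattern in the tests.
        return x, (kv if use_cache else None)
```

With this shape of API, call sites unpack a pair whether or not caching is used: a full pass stays `output, _ = decoder(x)`, while incremental decoding could look like `output, cache = decoder(x_new, cache=cache, use_cache=True)`.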
@@ -4,17 +4,17 @@ Tests for decoder block.
 import pytest
 import torch
-from llm.core.decoder import Decoder
+from llm.core.gpt_decoder import GptDecoder


-class TestDecoder:
+class TestGptDecoder:
     """Test cases for Decoder."""

     def test_initialization(self, embed_dim, num_heads):
         """Test that Decoder can be initialized."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -32,7 +32,7 @@ class TestDecoder:
         """Test forward pass of Decoder."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -40,7 +40,7 @@ class TestDecoder:
         )

         # Forward pass
-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)

         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -50,7 +50,7 @@ class TestDecoder:
         """Test forward pass with causal mask."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -62,7 +62,7 @@ class TestDecoder:
         mask = torch.tril(torch.ones(seq_len, seq_len))

         # Forward pass with causal mask
-        output = decoder(random_embeddings, mask=mask)
+        output, _ = decoder(random_embeddings, attention_mask=mask)

         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -71,14 +71,14 @@ class TestDecoder:
         """Test that residual connections are properly applied."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
             max_seq_len=max_seq_len,
         )

-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)

         # With residual connections and layer norm, the output shouldn't be
         # too different from input (in terms of scale/distribution)
@@ -92,14 +92,14 @@ class TestDecoder:
         """Test that layer normalization is applied."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
             max_seq_len=max_seq_len,
         )

-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)

         # Check that output has reasonable statistics (due to layer norm)
         # Mean should be close to 0, std close to 1 for each sequence position
@@ -114,7 +114,7 @@ class TestDecoder:
         """Test that gradients flow through Decoder."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -122,7 +122,7 @@ class TestDecoder:
         )

         # Forward pass
-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)

         # Create a dummy loss and backward pass
         loss = output.sum()
@@ -139,7 +139,7 @@ class TestDecoder:
         """Test that Decoder works on correct device."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -148,7 +148,7 @@ class TestDecoder:
         inputs = random_embeddings.to(device)

         # Forward pass
-        output = decoder(inputs)
+        output, _ = decoder(inputs)

         # Check device consistency
         assert output.device == device
@@ -165,7 +165,7 @@ class TestDecoder:
         for embed_dim, num_heads in test_cases:
             head_size = embed_dim // num_heads
             max_seq_len = 1024
-            decoder = Decoder(
+            decoder = GptDecoder(
                 num_heads=num_heads,
                 emb_size=embed_dim,
                 head_size=head_size,
@@ -174,7 +174,7 @@ class TestDecoder:
             batch_size, seq_len = 2, 16
             inputs = torch.randn(batch_size, seq_len, embed_dim)

-            output = decoder(inputs)
+            output, _ = decoder(inputs)

             assert output.shape == inputs.shape

@@ -183,7 +183,7 @@ class TestDecoder:
         """Test Decoder with different input shapes."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -191,7 +191,7 @@ class TestDecoder:
         )

         inputs = torch.randn(batch_size, seq_len, embed_dim)
-        output = decoder(inputs)
+        output, _ = decoder(inputs)

         assert output.shape == (batch_size, seq_len, embed_dim)

@@ -199,7 +199,7 @@ class TestDecoder:
         """Test that Decoder behaves differently in train vs eval mode."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -209,11 +209,11 @@ class TestDecoder:

         # Training mode
         decoder.train()
-        output_train = decoder(random_embeddings)
+        output_train, _ = decoder(random_embeddings)

         # Evaluation mode
         decoder.eval()
-        output_eval = decoder(random_embeddings)
+        output_eval, _ = decoder(random_embeddings)

         # Outputs should be different due to dropout
         assert not torch.allclose(output_train, output_eval)
@@ -222,7 +222,7 @@ class TestDecoder:
         """Test that parameters are properly initialized."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -30,7 +30,7 @@ class TestGPT:
         model = GPT(gpt_config)

         # Forward pass
-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)

         # Check output shape
         batch_size, seq_len = random_inputs.shape
@@ -45,7 +45,7 @@ class TestGPT:
         model = GPT(gpt_config)

         # Forward pass with mask
-        logits = model(random_inputs, attention_mask=attention_mask)
+        logits, _ = model(random_inputs, attention_mask=attention_mask)

         # Check output shape
         batch_size, seq_len = random_inputs.shape
@@ -132,7 +132,7 @@ class TestGPT:
         model = GPT(gpt_config)

         # Forward pass
-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)

         # Create a dummy loss and backward pass
         targets = torch.randint(0, gpt_config["vocab_size"], random_inputs.shape)
@@ -157,7 +157,7 @@ class TestGPT:
         inputs = random_inputs.to(device)

         # Forward pass
-        logits = model(inputs)
+        logits, _ = model(inputs)

         # Check device consistency
         assert logits.device == device
@@ -197,7 +197,7 @@ class TestGPT:
         batch_size, seq_len = 2, 16
         inputs = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

-        logits = model(inputs)
+        logits, _ = model(inputs)

         expected_shape = (batch_size, seq_len, config["vocab_size"])
         assert logits.shape == expected_shape
@@ -208,7 +208,7 @@ class TestGPT:
         model = GPT(gpt_config)

         inputs = torch.randint(0, gpt_config["vocab_size"], (batch_size, seq_len))
-        logits = model(inputs)
+        logits, _ = model(inputs)

         expected_shape = (batch_size, seq_len, gpt_config["vocab_size"])
         assert logits.shape == expected_shape
@@ -219,11 +219,11 @@ class TestGPT:

         # Training mode
         model.train()
-        output_train = model(random_inputs)
+        output_train, _ = model(random_inputs)

         # Evaluation mode
         model.eval()
-        output_eval = model(random_inputs)
+        output_eval, _ = model(random_inputs)

         # Outputs should be different due to dropout
         assert not torch.allclose(output_train, output_eval)
@@ -271,7 +271,7 @@ class TestGPT:
         """Test that GPT output has proper distribution."""
         model = GPT(gpt_config)

-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)

         # Logits should not have extreme values
         assert logits.abs().max() < 100
@@ -28,7 +28,7 @@ def test_gpt_model_creation():
     input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

     with torch.no_grad():
-        logits = model(input_ids)
+        logits, _ = model(input_ids)

     assert logits.shape == (batch_size, seq_len, config["vocab_size"])
     print("✅ GPT model creation and forward pass test passed")
@@ -222,7 +222,7 @@ def test_gpt_with_tokenizer():
     input_ids = torch.tensor([tokens])

     with torch.no_grad():
-        logits = model(input_ids)
+        logits, _ = model(input_ids)

     assert logits.shape == (1, len(tokens), vocab_size)
     print("✅ GPT with tokenizer integration test passed")
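Taken together, the changes above mean every caller now unpacks a pair from the model as well. As a quick, hypothetical usage note (not part of this commit), a greedy generation loop against the tuple-returning GPT forward shown in the tests could look like the sketch below; greedy_generate is an illustrative helper name, and the second element of the tuple is simply ignored here.

```python
import torch


def greedy_generate(model, input_ids, max_new_tokens=20):
    """Hypothetical greedy decoding loop for a tuple-returning GPT forward."""
    model.eval()
    ids = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits, _ = model(ids)  # second element is the cache slot, unused here
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)  # append the predicted token
    return ids
```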