mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
refactor(gpt1): migrate Decoder to GptDecoder, unify API, and update tests
- Renamed Decoder (and decoder.py) to GptDecoder (gpt_decoder.py) for clarity in GPT1 - Implemented support for cache and use_cache parameters in GptDecoder.forward (API unification) - Adapted all usages in GPT model to use new decoder structure and handle tuple output - Refactored core tests (test_gpt.py, test_gpt_decoder.py, test_basic.py) to correctly expect tuple or logits and ensure shape/device checks work as before - Improved clarity and future extensibility for autoregressive generation and benchmarking - No changes to architectural details or training loop; pure API and test modernization
This commit is contained in:
@@ -30,7 +30,7 @@ class TestGPT:
|
||||
model = GPT(gpt_config)
|
||||
|
||||
# Forward pass
|
||||
logits = model(random_inputs)
|
||||
logits, _ = model(random_inputs)
|
||||
|
||||
# Check output shape
|
||||
batch_size, seq_len = random_inputs.shape
|
||||
@@ -45,7 +45,7 @@ class TestGPT:
|
||||
model = GPT(gpt_config)
|
||||
|
||||
# Forward pass with mask
|
||||
logits = model(random_inputs, attention_mask=attention_mask)
|
||||
logits, _ = model(random_inputs, attention_mask=attention_mask)
|
||||
|
||||
# Check output shape
|
||||
batch_size, seq_len = random_inputs.shape
|
||||
@@ -132,7 +132,7 @@ class TestGPT:
|
||||
model = GPT(gpt_config)
|
||||
|
||||
# Forward pass
|
||||
logits = model(random_inputs)
|
||||
logits, _ = model(random_inputs)
|
||||
|
||||
# Create a dummy loss and backward pass
|
||||
targets = torch.randint(0, gpt_config["vocab_size"], random_inputs.shape)
|
||||
@@ -157,7 +157,7 @@ class TestGPT:
|
||||
inputs = random_inputs.to(device)
|
||||
|
||||
# Forward pass
|
||||
logits = model(inputs)
|
||||
logits, _ = model(inputs)
|
||||
|
||||
# Check device consistency
|
||||
assert logits.device == device
|
||||
@@ -197,7 +197,7 @@ class TestGPT:
|
||||
batch_size, seq_len = 2, 16
|
||||
inputs = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
|
||||
|
||||
logits = model(inputs)
|
||||
logits, _ = model(inputs)
|
||||
|
||||
expected_shape = (batch_size, seq_len, config["vocab_size"])
|
||||
assert logits.shape == expected_shape
|
||||
@@ -208,7 +208,7 @@ class TestGPT:
|
||||
model = GPT(gpt_config)
|
||||
|
||||
inputs = torch.randint(0, gpt_config["vocab_size"], (batch_size, seq_len))
|
||||
logits = model(inputs)
|
||||
logits, _ = model(inputs)
|
||||
|
||||
expected_shape = (batch_size, seq_len, gpt_config["vocab_size"])
|
||||
assert logits.shape == expected_shape
|
||||
@@ -219,11 +219,11 @@ class TestGPT:
|
||||
|
||||
# Training mode
|
||||
model.train()
|
||||
output_train = model(random_inputs)
|
||||
output_train, _ = model(random_inputs)
|
||||
|
||||
# Evaluation mode
|
||||
model.eval()
|
||||
output_eval = model(random_inputs)
|
||||
output_eval, _ = model(random_inputs)
|
||||
|
||||
# Outputs should be different due to dropout
|
||||
assert not torch.allclose(output_train, output_eval)
|
||||
@@ -271,7 +271,7 @@ class TestGPT:
|
||||
"""Test that GPT output has proper distribution."""
|
||||
model = GPT(gpt_config)
|
||||
|
||||
logits = model(random_inputs)
|
||||
logits, _ = model(random_inputs)
|
||||
|
||||
# Logits should not have extreme values
|
||||
assert logits.abs().max() < 100
|
||||
|
||||
Reference in New Issue
Block a user