mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
refactor(gpt1): migrate Decoder to GptDecoder, unify API, and update tests
- Renamed `Decoder` (and `decoder.py`) to `GptDecoder` (`gpt_decoder.py`) for clarity in GPT-1
- Implemented support for `cache` and `use_cache` parameters in `GptDecoder.forward` (API unification)
- Adapted all usages in the GPT model to use the new decoder structure and handle the tuple output
- Refactored core tests (`test_gpt.py`, `test_gpt_decoder.py`, `test_basic.py`) to correctly expect a tuple or logits, and ensured shape/device checks work as before
- Improved clarity and future extensibility for autoregressive generation and benchmarking
- No changes to architectural details or the training loop; pure API and test modernization
This commit is contained in:
@@ -28,7 +28,7 @@ def test_gpt_model_creation():
     input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

     with torch.no_grad():
-        logits = model(input_ids)
+        logits, _ = model(input_ids)

     assert logits.shape == (batch_size, seq_len, config["vocab_size"])
     print("✅ GPT model creation and forward pass test passed")
@@ -222,7 +222,7 @@ def test_gpt_with_tokenizer():
     input_ids = torch.tensor([tokens])

     with torch.no_grad():
-        logits = model(input_ids)
+        logits, _ = model(input_ids)

     assert logits.shape == (1, len(tokens), vocab_size)
     print("✅ GPT with tokenizer integration test passed")
Reference in New Issue
Block a user