refactor(gpt1): migrate Decoder to GptDecoder, unify API, and update tests

- Renamed Decoder (and decoder.py) to GptDecoder (gpt_decoder.py) for clarity in GPT1
- Implemented support for cache and use_cache parameters in GptDecoder.forward (API unification)
- Adapted all usages in GPT model to use new decoder structure and handle tuple output
- Refactored core tests (test_gpt.py, test_gpt_decoder.py, test_basic.py) to correctly expect tuple or logits and ensure shape/device checks work as before
- Improved clarity and future extensibility for autoregressive generation and benchmarking
- No changes to architectural details or training loop; pure API and test modernization
This commit is contained in:
Sergey Penkovsky
2025-10-22 16:27:08 +03:00
parent ddc4924a37
commit 25caf69ced
5 changed files with 113 additions and 60 deletions

View File

@@ -30,7 +30,7 @@ class TestGPT:
model = GPT(gpt_config)
# Forward pass
logits = model(random_inputs)
logits, _ = model(random_inputs)
# Check output shape
batch_size, seq_len = random_inputs.shape
@@ -45,7 +45,7 @@ class TestGPT:
model = GPT(gpt_config)
# Forward pass with mask
logits = model(random_inputs, attention_mask=attention_mask)
logits, _ = model(random_inputs, attention_mask=attention_mask)
# Check output shape
batch_size, seq_len = random_inputs.shape
@@ -132,7 +132,7 @@ class TestGPT:
model = GPT(gpt_config)
# Forward pass
logits = model(random_inputs)
logits, _ = model(random_inputs)
# Create a dummy loss and backward pass
targets = torch.randint(0, gpt_config["vocab_size"], random_inputs.shape)
@@ -157,7 +157,7 @@ class TestGPT:
inputs = random_inputs.to(device)
# Forward pass
logits = model(inputs)
logits, _ = model(inputs)
# Check device consistency
assert logits.device == device
@@ -197,7 +197,7 @@ class TestGPT:
batch_size, seq_len = 2, 16
inputs = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
logits = model(inputs)
logits, _ = model(inputs)
expected_shape = (batch_size, seq_len, config["vocab_size"])
assert logits.shape == expected_shape
@@ -208,7 +208,7 @@ class TestGPT:
model = GPT(gpt_config)
inputs = torch.randint(0, gpt_config["vocab_size"], (batch_size, seq_len))
logits = model(inputs)
logits, _ = model(inputs)
expected_shape = (batch_size, seq_len, gpt_config["vocab_size"])
assert logits.shape == expected_shape
@@ -219,11 +219,11 @@ class TestGPT:
# Training mode
model.train()
output_train = model(random_inputs)
output_train, _ = model(random_inputs)
# Evaluation mode
model.eval()
output_eval = model(random_inputs)
output_eval, _ = model(random_inputs)
# Outputs should be different due to dropout
assert not torch.allclose(output_train, output_eval)
@@ -271,7 +271,7 @@ class TestGPT:
"""Test that GPT output has proper distribution."""
model = GPT(gpt_config)
logits = model(random_inputs)
logits, _ = model(random_inputs)
# Logits should not have extreme values
assert logits.abs().max() < 100