Mirror of https://github.com/pese-git/llm-arch-research.git
Synced 2026-01-23 21:10:54 +00:00
refactor(gpt1): migrate Decoder to GptDecoder, unify API, and update tests
- Renamed Decoder (and decoder.py) to GptDecoder (gpt_decoder.py) for clarity in GPT1
- Implemented support for the cache and use_cache parameters in GptDecoder.forward (API unification)
- Adapted all usages in the GPT model to the new decoder structure and to handling the tuple output
- Refactored core tests (test_gpt.py, test_gpt_decoder.py, test_basic.py) to correctly expect a tuple (or bare logits) and to keep shape/device checks working as before
- Improved clarity and future extensibility for autoregressive generation and benchmarking
- No changes to architectural details or the training loop; pure API and test modernization
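At a glance, the unified call pattern this commit establishes — shown here as a minimal sketch, not code from the commit; the model construction and config fields are assumed from the tests further down:

    import torch

    model = GPT(gpt_config)  # gpt_config as used in the tests below
    input_ids = torch.randint(0, gpt_config["vocab_size"], (2, 16))

    # Every forward now returns a (logits, cache) tuple
    logits, new_cache = model(input_ids, use_cache=True)   # logits: [batch, seq_len, vocab_size]
    logits, _ = model(input_ids, use_cache=False)          # cache slot is None when caching is off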
@@ -4,7 +4,7 @@ from .feed_forward import FeedForward
 from .multi_head_attention import MultiHeadAttention


-class Decoder(nn.Module):
+class GptDecoder(nn.Module):
     """
     Decoder: a basic transformer decoder block (pre-LN), the classic building block of modern language models.

@@ -94,7 +94,13 @@ class Decoder(nn.Module):
         self._norm1 = nn.LayerNorm(emb_size)
         self._norm2 = nn.LayerNorm(emb_size)

-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        use_cache: bool = False,
+        cache: list = None,
+        attention_mask=None
+    ) -> tuple:
         """
         A single forward pass through the Transformer decoder block.

@@ -117,10 +123,16 @@ class Decoder(nn.Module):
         - Apply the FFN to the normalized (layernorm) result
         - Add the residual connection (ffn + the previous output)
         """

         # Self-Attention block
-        attention, _ = self._heads(x, mask, use_cache=False, cache=None)
+        attention, kv_caches = self._heads(x, attention_mask, use_cache=use_cache, cache=cache)
         out = self._norm1(attention + x)

         # FeedForward block
         ffn_out = self._ff(out)
-        return self._norm2(ffn_out + out)
+        result = self._norm2(ffn_out + out)
+
+        if use_cache:
+            return (result, kv_caches)
+        else:
+            return (result, None)
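The decoder-level contract after these hunks, as a hedged sketch: the constructor arguments mirror those in the tests, the concrete sizes are made up, and it is assumed that _heads returns updated per-head key/value pairs when use_cache=True.

    import torch
    from llm.core.gpt_decoder import GptDecoder

    decoder = GptDecoder(num_heads=4, emb_size=64, head_size=16, max_seq_len=128)
    x = torch.randn(2, 10, 64)  # [batch, seq_len, emb_size]

    out, kv = decoder(x, use_cache=False)                    # kv is None
    out, kv = decoder(x, use_cache=True)                     # kv holds the new key/value caches
    out, kv = decoder(x[:, -1:], use_cache=True, cache=kv)   # next step reuses the cache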
@@ -26,7 +26,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Dict
 from llm.core.base_model import BaseModel
-from llm.core.decoder import Decoder
+from llm.core.gpt_decoder import GptDecoder
 from llm.core.token_embeddings import TokenEmbeddings
 from llm.core.positional_embeddings import PositionalEmbeddings

@@ -116,7 +116,7 @@ class GPT(BaseModel):
         # head_size = emb_size // num_heads
         self._decoders = nn.ModuleList(
             [
-                Decoder(
+                GptDecoder(
                     num_heads=config["num_heads"],
                     emb_size=config["embed_dim"],
                     head_size=config["embed_dim"] // config["num_heads"],
@@ -133,7 +133,9 @@ class GPT(BaseModel):
         """Returns the maximum sequence length."""
         return self._max_seq_len

-    def forward(self, x: torch.Tensor, attention_mask=None) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor, attention_mask=None, use_cache: bool = True, cache: list = None
+    ) -> tuple:
         """
         Forward pass that produces logits for a sequence of tokens.

@@ -157,33 +159,60 @@ class GPT(BaseModel):
                 f"Sequence length {x.size(1)} exceeds the maximum {self._max_seq_len}"
             )

+        # Compute start_pos from the cache (if a cache was passed)
+        if cache is not None:
+            seq_len = 1
+            # Safely extract key_cache to compute start_pos
+            if (
+                isinstance(cache, (list, tuple))
+                and len(cache) > 0
+                and cache[0] is not None
+                and isinstance(cache[0], (list, tuple))
+                and len(cache[0]) > 0
+                and cache[0][0] is not None
+                and isinstance(cache[0][0], (tuple, list))
+                and len(cache[0][0]) > 0
+            ):
+                key_cache, _ = cache[0][0]
+                start_pos = key_cache.size(1)
+            else:
+                start_pos = 0
+        else:
+            # Without a cache, behave as before
+            start_pos = 0
+            seq_len = x.size(1)
+
         # Token and position embeddings
         tok_out = self._token_embeddings(x)  # [batch, seq_len, emb_size]
-        pos_out = self._position_embeddings(x.size(1))  # [seq_len, emb_size]
+        pos_out = self._position_embeddings(
+            seq_len, start_pos=start_pos
+        )  # [seq_len, emb_size]

         # Combine
         out = self._dropout(
             tok_out + pos_out.unsqueeze(0)
         )  # [batch, seq_len, emb_size]

-        # Decoder stack
-        for decoder in self._decoders:
-            out = decoder(out)
-
-        return self._linear(out)  # [batch, seq_len, vocab_size]
-
-        # def forward(self, input_ids, attention_mask=None):
-        # B, T = input_ids.size()
-        # pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
-        #
-        # x = self.token_emb(input_ids) + self.pos_emb(pos)
-        #
-        # for block in self.blocks:
-        # x = block(x, attention_mask)
-        #
-        # x = self.ln_f(x)
-        # logits = self.head(x)
-        # return logits
+        # Decoder stack, passing the cache through
+        new_cache = []
+        for i, decoder in enumerate(self._decoders):
+            decoder_cache = cache[i] if cache is not None else None
+            decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
+
+            # Unpack the result from the tuple
+            if use_cache:
+                out, decoder_new_cache = decoder_result
+                new_cache.append(decoder_new_cache)
+            else:
+                out = decoder_result[0]
+
+        logits = self._linear(out)  # [batch, seq_len, vocab_size]
+
+        # Return the result according to use_cache
+        if use_cache:
+            return (logits, new_cache)
+        else:
+            return (logits, None)

     def generate(
         self,
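The start_pos checks above imply a nested cache layout; the following is an inference from the code, not a documented spec:

    # cache[i]    -> cache returned by decoder layer i
    # cache[i][j] -> (key_cache, value_cache) pair for attention head j
    # key_cache.size(1) == number of positions already cached == start_pos
    key_cache, value_cache = cache[0][0]
    start_pos = key_cache.size(1)  # positions already processed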
@@ -245,12 +274,24 @@ class GPT(BaseModel):
         - Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
         - Original GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
         """
+        cache = None
+
         for _ in range(max_new_tokens):
             # 1. Truncate the input if the sequence is too long
-            x_cond = x[:, -self._max_seq_len :]
+            if use_cache and cache is not None:
+                # Use the cache: pass only the last token
+                x_input = x[:, -1:]  # [batch_size, 1]
+            else:
+                # First iteration, or cache disabled: pass the whole sequence
+                x_input = x

             # 2. Pass the sequence to GPT's forward method and obtain the logits.
-            logits = self.forward(x_cond)
+            # Forward pass with the cache
+            logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
+
+            # Update the cache for the next iteration
+            if use_cache:
+                cache = new_cache

             # 3. Take the logits for the last token
             last_logits = logits[:, -1, :]  # [batch_size, vocab_size]
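The practical effect of the loop above, sketched under the assumption that use_cache is a generate() parameter (its use inside the loop suggests so): the first iteration prefills the cache with the whole prompt, and every later iteration feeds a single token.

    # Hypothetical timeline for a prompt of length T and 3 new tokens:
    #   step 1: forward(x)          -> attention over T positions, cache filled
    #   step 2: forward(x[:, -1:])  -> one query against T+1 cached positions
    #   step 3: forward(x[:, -1:])  -> one query against T+2 cached positions
    # Without the cache, each step would re-encode the whole growing sequence.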
@@ -4,17 +4,17 @@ Tests for decoder block.

 import pytest
 import torch
-from llm.core.decoder import Decoder
+from llm.core.gpt_decoder import GptDecoder


-class TestDecoder:
+class TestGptDecoder:
     """Test cases for Decoder."""

     def test_initialization(self, embed_dim, num_heads):
         """Test that Decoder can be initialized."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -32,7 +32,7 @@ class TestDecoder:
         """Test forward pass of Decoder."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -40,7 +40,7 @@ class TestDecoder:
         )

         # Forward pass
-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)

         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -50,7 +50,7 @@ class TestDecoder:
         """Test forward pass with causal mask."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -62,7 +62,7 @@ class TestDecoder:
         mask = torch.tril(torch.ones(seq_len, seq_len))

         # Forward pass with causal mask
-        output = decoder(random_embeddings, mask=mask)
+        output, _ = decoder(random_embeddings, attention_mask=mask)

         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -71,14 +71,14 @@ class TestDecoder:
         """Test that residual connections are properly applied."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
             max_seq_len=max_seq_len,
         )

-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)

         # With residual connections and layer norm, the output shouldn't be
         # too different from input (in terms of scale/distribution)
@@ -92,14 +92,14 @@ class TestDecoder:
         """Test that layer normalization is applied."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
             max_seq_len=max_seq_len,
         )

-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)

         # Check that output has reasonable statistics (due to layer norm)
         # Mean should be close to 0, std close to 1 for each sequence position
@@ -114,7 +114,7 @@ class TestDecoder:
         """Test that gradients flow through Decoder."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -122,7 +122,7 @@ class TestDecoder:
         )

         # Forward pass
-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)

         # Create a dummy loss and backward pass
         loss = output.sum()
@@ -139,7 +139,7 @@ class TestDecoder:
         """Test that Decoder works on correct device."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -148,7 +148,7 @@ class TestDecoder:
         inputs = random_embeddings.to(device)

         # Forward pass
-        output = decoder(inputs)
+        output, _ = decoder(inputs)

         # Check device consistency
         assert output.device == device
@@ -165,7 +165,7 @@ class TestDecoder:
         for embed_dim, num_heads in test_cases:
             head_size = embed_dim // num_heads
             max_seq_len = 1024
-            decoder = Decoder(
+            decoder = GptDecoder(
                 num_heads=num_heads,
                 emb_size=embed_dim,
                 head_size=head_size,
@@ -174,7 +174,7 @@ class TestDecoder:
             batch_size, seq_len = 2, 16
             inputs = torch.randn(batch_size, seq_len, embed_dim)

-            output = decoder(inputs)
+            output, _ = decoder(inputs)

             assert output.shape == inputs.shape

@@ -183,7 +183,7 @@ class TestDecoder:
         """Test Decoder with different input shapes."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -191,7 +191,7 @@ class TestDecoder:
         )

         inputs = torch.randn(batch_size, seq_len, embed_dim)
-        output = decoder(inputs)
+        output, _ = decoder(inputs)

         assert output.shape == (batch_size, seq_len, embed_dim)

@@ -199,7 +199,7 @@ class TestDecoder:
         """Test that Decoder behaves differently in train vs eval mode."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -209,11 +209,11 @@ class TestDecoder:

         # Training mode
         decoder.train()
-        output_train = decoder(random_embeddings)
+        output_train, _ = decoder(random_embeddings)

         # Evaluation mode
         decoder.eval()
-        output_eval = decoder(random_embeddings)
+        output_eval, _ = decoder(random_embeddings)

         # Outputs should be different due to dropout
         assert not torch.allclose(output_train, output_eval)
@@ -222,7 +222,7 @@ class TestDecoder:
         """Test that parameters are properly initialized."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -30,7 +30,7 @@ class TestGPT:
         model = GPT(gpt_config)

         # Forward pass
-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)

         # Check output shape
         batch_size, seq_len = random_inputs.shape
@@ -45,7 +45,7 @@ class TestGPT:
         model = GPT(gpt_config)

         # Forward pass with mask
-        logits = model(random_inputs, attention_mask=attention_mask)
+        logits, _ = model(random_inputs, attention_mask=attention_mask)

         # Check output shape
         batch_size, seq_len = random_inputs.shape
@@ -132,7 +132,7 @@ class TestGPT:
         model = GPT(gpt_config)

         # Forward pass
-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)

         # Create a dummy loss and backward pass
         targets = torch.randint(0, gpt_config["vocab_size"], random_inputs.shape)
@@ -157,7 +157,7 @@ class TestGPT:
         inputs = random_inputs.to(device)

         # Forward pass
-        logits = model(inputs)
+        logits, _ = model(inputs)

         # Check device consistency
         assert logits.device == device
@@ -197,7 +197,7 @@ class TestGPT:
             batch_size, seq_len = 2, 16
             inputs = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

-            logits = model(inputs)
+            logits, _ = model(inputs)

             expected_shape = (batch_size, seq_len, config["vocab_size"])
             assert logits.shape == expected_shape
@@ -208,7 +208,7 @@ class TestGPT:
         model = GPT(gpt_config)

         inputs = torch.randint(0, gpt_config["vocab_size"], (batch_size, seq_len))
-        logits = model(inputs)
+        logits, _ = model(inputs)

         expected_shape = (batch_size, seq_len, gpt_config["vocab_size"])
         assert logits.shape == expected_shape
@@ -219,11 +219,11 @@ class TestGPT:

         # Training mode
         model.train()
-        output_train = model(random_inputs)
+        output_train, _ = model(random_inputs)

         # Evaluation mode
         model.eval()
-        output_eval = model(random_inputs)
+        output_eval, _ = model(random_inputs)

         # Outputs should be different due to dropout
         assert not torch.allclose(output_train, output_eval)
@@ -271,7 +271,7 @@ class TestGPT:
         """Test that GPT output has proper distribution."""
         model = GPT(gpt_config)

-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)

         # Logits should not have extreme values
         assert logits.abs().max() < 100
@@ -28,7 +28,7 @@ def test_gpt_model_creation():
     input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))

     with torch.no_grad():
-        logits = model(input_ids)
+        logits, _ = model(input_ids)

     assert logits.shape == (batch_size, seq_len, config["vocab_size"])
     print("✅ GPT model creation and forward pass test passed")
@@ -222,7 +222,7 @@ def test_gpt_with_tokenizer():
     input_ids = torch.tensor([tokens])

     with torch.no_grad():
-        logits = model(input_ids)
+        logits, _ = model(input_ids)

     assert logits.shape == (1, len(tokens), vocab_size)
     print("✅ GPT with tokenizer integration test passed")