diff --git a/llm/src/llm/core/decoder.py b/llm/src/llm/core/gpt_decoder.py
similarity index 92%
rename from llm/src/llm/core/decoder.py
rename to llm/src/llm/core/gpt_decoder.py
index 11de653..1960b6a 100644
--- a/llm/src/llm/core/decoder.py
+++ b/llm/src/llm/core/gpt_decoder.py
@@ -4,7 +4,7 @@
 from .feed_forward import FeedForward
 from .multi_head_attention import MultiHeadAttention
 
-class Decoder(nn.Module):
+class GptDecoder(nn.Module):
     """
     Decoder: a basic Transformer decoder block (pre-LN),
     the classic building block of modern language models.
@@ -94,7 +94,13 @@ class Decoder(nn.Module):
         self._norm1 = nn.LayerNorm(emb_size)
         self._norm2 = nn.LayerNorm(emb_size)
 
-    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        use_cache: bool = False,
+        cache: list = None,
+        attention_mask=None
+    ) -> tuple:
         """
         A single forward pass through the Transformer decoder block.
@@ -117,10 +123,16 @@ class Decoder(nn.Module):
         - Apply the FFN to the normalized output (layernorm)
         - Add the residual connection (ffn + previous output)
         """
+        # Self-Attention block
-        attention, _ = self._heads(x, mask, use_cache=False, cache=None)
+        attention, kv_caches = self._heads(x, attention_mask, use_cache=use_cache, cache=cache)
         out = self._norm1(attention + x)
 
         # FeedForward block
         ffn_out = self._ff(out)
 
-        return self._norm2(ffn_out + out)
+        result = self._norm2(ffn_out + out)
+
+        if use_cache:
+            return (result, kv_caches)
+        else:
+            return (result, None)
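For reference, a minimal sketch of the two-phase (prefill, then decode) call pattern this change enables at the block level. Only the `(output, cache)` tuple contract comes from the diff above; the internal layout of `kv_caches` belongs to `MultiHeadAttention` and is assumed here, as is the assumption that the second call returns an updated cache.

```python
# Sketch only: incremental decoding through a single GptDecoder block.
import torch
from llm.core.gpt_decoder import GptDecoder

decoder = GptDecoder(num_heads=4, emb_size=64, head_size=16, max_seq_len=128)
decoder.eval()  # disable dropout so the two phases stay comparable

# Prefill: run the whole prompt once and keep the returned cache.
prompt = torch.randn(1, 10, 64)                      # [batch, seq_len, emb_size]
out, layer_cache = decoder(prompt, use_cache=True)

# Decode: feed only the newest position, reusing cached keys/values.
next_tok = torch.randn(1, 1, 64)
out_next, layer_cache = decoder(next_tok, use_cache=True, cache=layer_cache)
```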
diff --git a/llm/src/llm/models/gpt/gpt.py b/llm/src/llm/models/gpt/gpt.py
index d3d5dfc..cd947c1 100644
--- a/llm/src/llm/models/gpt/gpt.py
+++ b/llm/src/llm/models/gpt/gpt.py
@@ -26,7 +26,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from typing import Optional, Dict
 from llm.core.base_model import BaseModel
-from llm.core.decoder import Decoder
+from llm.core.gpt_decoder import GptDecoder
 from llm.core.token_embeddings import TokenEmbeddings
 from llm.core.positional_embeddings import PositionalEmbeddings
@@ -116,7 +116,7 @@ class GPT(BaseModel):
         # head_size = emb_size // num_heads
         self._decoders = nn.ModuleList(
             [
-                Decoder(
+                GptDecoder(
                     num_heads=config["num_heads"],
                     emb_size=config["embed_dim"],
                     head_size=config["embed_dim"] // config["num_heads"],
@@ -133,7 +133,9 @@ class GPT(BaseModel):
         """Returns the maximum sequence length."""
         return self._max_seq_len
 
-    def forward(self, x: torch.Tensor, attention_mask=None) -> torch.Tensor:
+    def forward(
+        self, x: torch.Tensor, attention_mask=None, use_cache: bool = True, cache: list = None
+    ) -> tuple:
         """
         Forward pass that produces logits for a sequence of tokens.
@@ -157,33 +159,60 @@ class GPT(BaseModel):
                 f"Sequence length {x.size(1)} exceeds the maximum {self._max_seq_len}"
             )
 
+        # Compute start_pos from the cache (if a cache was passed)
+        if cache is not None:
+            seq_len = 1
+            # Safely extract key_cache to compute start_pos
+            if (
+                isinstance(cache, (list, tuple))
+                and len(cache) > 0
+                and cache[0] is not None
+                and isinstance(cache[0], (list, tuple))
+                and len(cache[0]) > 0
+                and cache[0][0] is not None
+                and isinstance(cache[0][0], (tuple, list))
+                and len(cache[0][0]) > 0
+            ):
+                key_cache, _ = cache[0][0]
+                start_pos = key_cache.size(1)
+            else:
+                start_pos = 0
+        else:
+            # Without a cache, behave as before
+            start_pos = 0
+            seq_len = x.size(1)
+
         # Token and positional embeddings
         tok_out = self._token_embeddings(x)  # [batch, seq_len, emb_size]
-        pos_out = self._position_embeddings(x.size(1))  # [seq_len, emb_size]
+        pos_out = self._position_embeddings(
+            seq_len, start_pos=start_pos
+        )  # [seq_len, emb_size]
 
         # Combine embeddings
         out = self._dropout(
             tok_out + pos_out.unsqueeze(0)
         )  # [batch, seq_len, emb_size]
 
-        # Decoder stack
-        for decoder in self._decoders:
-            out = decoder(out)
+        # Decoder stack, threading the cache through
+        new_cache = []
+        for i, decoder in enumerate(self._decoders):
+            decoder_cache = cache[i] if cache is not None else None
+            decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
 
-        return self._linear(out)  # [batch, seq_len, vocab_size]
+            # Unpack the result tuple
+            if use_cache:
+                out, decoder_new_cache = decoder_result
+                new_cache.append(decoder_new_cache)
+            else:
+                out = decoder_result[0]
 
-    # def forward(self, input_ids, attention_mask=None):
-    #     B, T = input_ids.size()
-    #     pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
-    #
-    #     x = self.token_emb(input_ids) + self.pos_emb(pos)
-    #
-    #     for block in self.blocks:
-    #         x = block(x, attention_mask)
-    #
-    #     x = self.ln_f(x)
-    #     logits = self.head(x)
-    #     return logits
+        logits = self._linear(out)  # [batch, seq_len, vocab_size]
+
+        # Return according to use_cache
+        if use_cache:
+            return (logits, new_cache)
+        else:
+            return (logits, None)
 
     def generate(
         self,
@@ -245,12 +274,24 @@ class GPT(BaseModel):
         - Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
         - Original GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
         """
+        cache = None
+
         for _ in range(max_new_tokens):
             # 1. Truncate the input if the sequence is too long
-            x_cond = x[:, -self._max_seq_len :]
+            if use_cache and cache is not None:
+                # Using the cache: pass only the last token
+                x_input = x[:, -1:]  # [batch_size, 1]
+            else:
+                # First iteration or cache disabled: pass the whole sequence
+                x_input = x
 
             # 2. Pass the sequence to GPT.forward and get the logits.
-            logits = self.forward(x_cond)
+            # Forward pass with the cache
+            logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
+
+            # Update the cache for the next iteration
+            if use_cache:
+                cache = new_cache
 
             # 3. Take the logits for the last token
             last_logits = logits[:, -1, :]  # [batch_size, vocab_size]
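The cache changes what the positional lookup has to do: during cached decoding, forward only sees the newest token, so `PositionalEmbeddings` must be told the absolute position to start from. Its real implementation is not part of this diff; the snippet below, with a hypothetical class name, is only an illustrative sketch of what a `start_pos`-aware lookup looks like.

```python
import torch
import torch.nn as nn

class OffsetPositionalEmbeddings(nn.Module):
    """Illustrative stand-in: embeddings for positions [start_pos, start_pos + seq_len)."""

    def __init__(self, max_seq_len: int, emb_size: int):
        super().__init__()
        self._emb = nn.Embedding(max_seq_len, emb_size)

    def forward(self, seq_len: int, start_pos: int = 0) -> torch.Tensor:
        positions = torch.arange(
            start_pos, start_pos + seq_len, device=self._emb.weight.device
        )
        return self._emb(positions)  # [seq_len, emb_size]

# During cached generation each step has seq_len == 1 and start_pos == number of
# cached positions, so the N-th token still receives the N-th positional embedding.
```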
diff --git a/llm/tests/core/test_decoder.py b/llm/tests/core/test_gpt_decoder.py
similarity index 89%
rename from llm/tests/core/test_decoder.py
rename to llm/tests/core/test_gpt_decoder.py
index 8eae46f..d1632c1 100644
--- a/llm/tests/core/test_decoder.py
+++ b/llm/tests/core/test_gpt_decoder.py
@@ -4,17 +4,17 @@ Tests for decoder block.
 import pytest
 import torch
 
-from llm.core.decoder import Decoder
+from llm.core.gpt_decoder import GptDecoder
 
 
-class TestDecoder:
+class TestGptDecoder:
     """Test cases for Decoder."""
 
     def test_initialization(self, embed_dim, num_heads):
         """Test that Decoder can be initialized."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -32,7 +32,7 @@ class TestDecoder:
         """Test forward pass of Decoder."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -40,7 +40,7 @@ class TestDecoder:
         )
 
         # Forward pass
-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)
 
         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -50,7 +50,7 @@ class TestDecoder:
         """Test forward pass with causal mask."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -62,7 +62,7 @@ class TestDecoder:
         mask = torch.tril(torch.ones(seq_len, seq_len))
 
         # Forward pass with causal mask
-        output = decoder(random_embeddings, mask=mask)
+        output, _ = decoder(random_embeddings, attention_mask=mask)
 
         # Check output shape
         assert output.shape == random_embeddings.shape
@@ -71,14 +71,14 @@ class TestDecoder:
         """Test that residual connections are properly applied."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
             max_seq_len=max_seq_len,
         )
 
-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)
 
         # With residual connections and layer norm, the output shouldn't be
         # too different from input (in terms of scale/distribution)
@@ -92,14 +92,14 @@ class TestDecoder:
         """Test that layer normalization is applied."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
             max_seq_len=max_seq_len,
         )
 
-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)
 
         # Check that output has reasonable statistics (due to layer norm)
         # Mean should be close to 0, std close to 1 for each sequence position
@@ -114,7 +114,7 @@ class TestDecoder:
         """Test that gradients flow through Decoder."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -122,7 +122,7 @@ class TestDecoder:
         )
 
         # Forward pass
-        output = decoder(random_embeddings)
+        output, _ = decoder(random_embeddings)
 
         # Create a dummy loss and backward pass
         loss = output.sum()
@@ -139,7 +139,7 @@ class TestDecoder:
         """Test that Decoder works on correct device."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -148,7 +148,7 @@ class TestDecoder:
         inputs = random_embeddings.to(device)
 
         # Forward pass
-        output = decoder(inputs)
+        output, _ = decoder(inputs)
 
         # Check device consistency
         assert output.device == device
@@ -165,7 +165,7 @@ class TestDecoder:
         for embed_dim, num_heads in test_cases:
             head_size = embed_dim // num_heads
             max_seq_len = 1024
-            decoder = Decoder(
+            decoder = GptDecoder(
                 num_heads=num_heads,
                 emb_size=embed_dim,
                 head_size=head_size,
@@ -174,7 +174,7 @@ class TestDecoder:
             batch_size, seq_len = 2, 16
             inputs = torch.randn(batch_size, seq_len, embed_dim)
 
-            output = decoder(inputs)
+            output, _ = decoder(inputs)
 
             assert output.shape == inputs.shape
 
@@ -183,7 +183,7 @@ class TestDecoder:
         """Test Decoder with different input shapes."""
        head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -191,7 +191,7 @@ class TestDecoder:
         )
 
         inputs = torch.randn(batch_size, seq_len, embed_dim)
-        output = decoder(inputs)
+        output, _ = decoder(inputs)
 
         assert output.shape == (batch_size, seq_len, embed_dim)
 
@@ -199,7 +199,7 @@ class TestDecoder:
         """Test that Decoder behaves differently in train vs eval mode."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
@@ -209,11 +209,11 @@ class TestDecoder:
 
         # Training mode
         decoder.train()
-        output_train = decoder(random_embeddings)
+        output_train, _ = decoder(random_embeddings)
 
         # Evaluation mode
         decoder.eval()
-        output_eval = decoder(random_embeddings)
+        output_eval, _ = decoder(random_embeddings)
 
         # Outputs should be different due to dropout
         assert not torch.allclose(output_train, output_eval)
@@ -222,7 +222,7 @@ class TestDecoder:
         """Test that parameters are properly initialized."""
         head_size = embed_dim // num_heads
         max_seq_len = 1024
-        decoder = Decoder(
+        decoder = GptDecoder(
             num_heads=num_heads,
             emb_size=embed_dim,
             head_size=head_size,
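One test this diff does not add yet, sketched here as a possible follow-up: with dropout disabled, decoding the last position incrementally through the cache should reproduce the last position of a full pass, since keys and values are per-position projections of the block input and the cached prefix is therefore exactly what the full pass computes. The fixture names mirror the existing tests; the tolerance is a guess.

```python
def test_forward_with_cache_matches_full_pass(self, embed_dim, num_heads, random_embeddings):
    """Sketch: cached single-step output should match the last position of a full pass."""
    head_size = embed_dim // num_heads
    decoder = GptDecoder(
        num_heads=num_heads,
        emb_size=embed_dim,
        head_size=head_size,
        max_seq_len=1024,
    )
    decoder.eval()  # dropout off, otherwise the two passes differ

    full_out, _ = decoder(random_embeddings, use_cache=False)

    # Prefill on everything but the last position, then decode the last one.
    _, cache = decoder(random_embeddings[:, :-1, :], use_cache=True)
    step_out, _ = decoder(random_embeddings[:, -1:, :], use_cache=True, cache=cache)

    assert torch.allclose(full_out[:, -1:, :], step_out, atol=1e-5)
```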
diff --git a/llm/tests/models/test_gpt.py b/llm/tests/models/test_gpt.py
index 61f90d2..48a0101 100644
--- a/llm/tests/models/test_gpt.py
+++ b/llm/tests/models/test_gpt.py
@@ -30,7 +30,7 @@ class TestGPT:
         model = GPT(gpt_config)
 
         # Forward pass
-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)
 
         # Check output shape
         batch_size, seq_len = random_inputs.shape
@@ -45,7 +45,7 @@ class TestGPT:
         model = GPT(gpt_config)
 
         # Forward pass with mask
-        logits = model(random_inputs, attention_mask=attention_mask)
+        logits, _ = model(random_inputs, attention_mask=attention_mask)
 
         # Check output shape
         batch_size, seq_len = random_inputs.shape
@@ -132,7 +132,7 @@ class TestGPT:
         model = GPT(gpt_config)
 
         # Forward pass
-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)
 
         # Create a dummy loss and backward pass
         targets = torch.randint(0, gpt_config["vocab_size"], random_inputs.shape)
@@ -157,7 +157,7 @@ class TestGPT:
         inputs = random_inputs.to(device)
 
         # Forward pass
-        logits = model(inputs)
+        logits, _ = model(inputs)
 
         # Check device consistency
         assert logits.device == device
@@ -197,7 +197,7 @@ class TestGPT:
         batch_size, seq_len = 2, 16
         inputs = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
 
-        logits = model(inputs)
+        logits, _ = model(inputs)
 
         expected_shape = (batch_size, seq_len, config["vocab_size"])
         assert logits.shape == expected_shape
@@ -208,7 +208,7 @@ class TestGPT:
         model = GPT(gpt_config)
         inputs = torch.randint(0, gpt_config["vocab_size"], (batch_size, seq_len))
 
-        logits = model(inputs)
+        logits, _ = model(inputs)
 
         expected_shape = (batch_size, seq_len, gpt_config["vocab_size"])
         assert logits.shape == expected_shape
@@ -219,11 +219,11 @@ class TestGPT:
 
         # Training mode
         model.train()
-        output_train = model(random_inputs)
+        output_train, _ = model(random_inputs)
 
         # Evaluation mode
         model.eval()
-        output_eval = model(random_inputs)
+        output_eval, _ = model(random_inputs)
 
         # Outputs should be different due to dropout
         assert not torch.allclose(output_train, output_eval)
@@ -271,7 +271,7 @@ class TestGPT:
         """Test that GPT output has proper distribution."""
         model = GPT(gpt_config)
 
-        logits = model(random_inputs)
+        logits, _ = model(random_inputs)
 
         # Logits should not have extreme values
         assert logits.abs().max() < 100
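A matching end-to-end check could live in this file as well. The sketch below assumes the same `gpt_config` / `random_inputs` fixtures used above and that `MultiHeadAttention` applies its causal mask internally; if it does not, only the final step's logits are comparable. It is a suggestion, not part of the diff.

```python
def test_cached_generation_matches_full_forward(self, gpt_config, random_inputs):
    """Sketch: token-by-token decoding with the cache should agree with a full pass."""
    model = GPT(gpt_config)
    model.eval()

    full_logits, _ = model(random_inputs, use_cache=False)

    # Feed one token at a time, carrying the cache forward.
    cache = None
    step_logits = []
    for t in range(random_inputs.size(1)):
        logits, cache = model(random_inputs[:, t : t + 1], use_cache=True, cache=cache)
        step_logits.append(logits)

    assert torch.allclose(full_logits, torch.cat(step_logits, dim=1), atol=1e-4)
```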
diff --git a/llm/tests/test_basic.py b/llm/tests/test_basic.py
index 8d18689..0565d22 100644
--- a/llm/tests/test_basic.py
+++ b/llm/tests/test_basic.py
@@ -28,7 +28,7 @@ def test_gpt_model_creation():
     input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
 
     with torch.no_grad():
-        logits = model(input_ids)
+        logits, _ = model(input_ids)
 
     assert logits.shape == (batch_size, seq_len, config["vocab_size"])
     print("✅ GPT model creation and forward pass test passed")
@@ -222,7 +222,7 @@ def test_gpt_with_tokenizer():
     input_ids = torch.tensor([tokens])
 
     with torch.no_grad():
-        logits = model(input_ids)
+        logits, _ = model(input_ids)
 
     assert logits.shape == (1, len(tokens), vocab_size)
     print("✅ GPT with tokenizer integration test passed")