From 923aa51e2a27e7103e08343f1adf3c56960cb08e Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Thu, 16 Oct 2025 12:30:53 +0300 Subject: [PATCH] docs(core): add docstrings and unit tests for CachedDecoder module - docs: Add detailed docstrings for CachedDecoder class and its methods (__init__, forward); explain autoregressive caching, architecture, math, usage, and links to GPT-2/LLM references - test: Add comprehensive unit tests for CachedDecoder (initialization, forward with and without cache, cache chaining, output shape, error on long input, backward pass) - These changes improve code clarity, reliability, and testing for decoder blocks with KV cache. --- llm/src/llm/core/cached_decoder.py | 122 ++++++++++++++++---------- llm/tests/core/test_cached_decoder.py | 65 ++++++++++++++ 2 files changed, 139 insertions(+), 48 deletions(-) create mode 100644 llm/tests/core/test_cached_decoder.py diff --git a/llm/src/llm/core/cached_decoder.py b/llm/src/llm/core/cached_decoder.py index 7cac0ff..c149b3c 100644 --- a/llm/src/llm/core/cached_decoder.py +++ b/llm/src/llm/core/cached_decoder.py @@ -9,33 +9,46 @@ from .rope import RoPE class CachedDecoder(nn.Module): """ - Универсальный декодерный блок для современных LLM (GPT, LLaMA, др.), поддерживает кэширование key-value для эффективной генерации. + CachedDecoder — Transformer-декодер с key/value-кэшированием (реализация накладывающегося masked multi-head attention). - Научная идея: - Автопагрессивная авторегрессия в трансформерах требует быстрого доступа к ранее вычисленным self-attention ключам/значениям — этот класс позволяет прозрачно кэшировать такие состояния для быстрой инференс-генерации. + Назначение: + ----------- + Позволяет быстро и эффективно реализовывать autoregressive генерацию текста в стиле GPT-2/3/4: + - На шаге генерации используются только нужные токены, “прошлые” key/value значения не пересчитываются, а подаются из кэша. + - Позволяет значительно ускорять inferece (особенно на длинных последовательностях). + - Вдохновлено реализациями в HuggingFace transformers, GPT-2/3 и других LLM. - Алгоритм: - - Input -> LayerNorm -> Многоголовое внимание с кэшем (может быть RoPE) - - Суммируем residual - - LayerNorm -> FeedForward (любой, например SwiGLU) -> Residual - - Возвращается кортеж (output, kvcache) + Архитектурные особенности: + -------------------------- + - Использует классическую multi-head attention (с causal mask — запрещает видеть “будущее”). + - Предусматривает передачу и накопление KV-cache для каждого слоя (hidden state attention). + - Поддерживает передачу внимания через стек attention-блоков. + - Применяется layernorm и feed-forward block (GELU). - Args: - feed_forward_layer (nn.Module): FeedForward или SwiGLU слой - num_heads (int): Количество голов внимания - emb_size (int): Размерность эмбеддингов - head_size (int): Размерность головы внимания - max_seq_len (int): Максимальная длина - norm_layer (тип nn.Module): Normalization слой (LayerNorm или RMSNorm) - dropout (float): Dropout - rope (RoPE|None): Экземпляр RoPE (для LLaMA) + Параметры конструктора: + ----------------------- + num_heads : int — число attention heads + emb_size : int — embedding размерность + head_size : int — размер каждой attention head (обычно emb_size // num_heads) + feed_forward_layer : nn.Module — feedforward блок (mLP), может быть любым PyTorch-слоем + max_seq_len : int — максимально допустимая длина последовательности + dropout : float — dropout на attention/ffn + + Пример использования: + --------------------- + >>> from llm.core.feed_forward import FeedForward + >>> ff_block = FeedForward(emb_size=256, dropout=0.1, activation=\"gelu\") + >>> decoder = CachedDecoder(num_heads=4, emb_size=256, head_size=64, feed_forward_layer=ff_block, max_seq_len=2048, dropout=0.1) + >>> x = torch.randn(2, 100, 256) + >>> y, kv_cache = decoder(x, use_cache=True, cache=None) + >>> print(y.shape) # torch.Size([2, 100, 256]) + + Подробнее: + ---------- + - GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf + - HuggingFace cache mechanics: https://huggingface.co/docs/transformers/main/en/model_doc/gpt2 + - Объяснения autoregressive cache: https://jalammar.github.io/illustrated-gpt2/ - Пример (GPT2 style): - >>> decoder = CachedDecoder( - ... feed_forward_layer=FeedForward(...), - ... norm_layer=nn.LayerNorm, - ... num_heads=4, emb_size=256, head_size=64, max_seq_len=128) - >>> out, cache = decoder(x, use_cache=True) """ def __init__( @@ -50,20 +63,22 @@ class CachedDecoder(nn.Module): rope: RoPE = None, ): """ - Инициализация декодера с кэшированием. + Конструктор CachedDecoder. - Поведение аналогично блоку TransformerDecoderLayer, - но с гибкой возможностью подмены любых подкомпонент (активация, norm, позиции). - - Args: - feed_forward_layer: Слой feed-forward (должен быть экземпляром, а не классом) - num_heads: Количество голов внимания - emb_size: Размерность эмбеддингов - head_size: Размерность каждой головы - max_seq_len: Максимальная длина последовательности - norm_layer: Класс нормализации (по умолчанию LayerNorm) - dropout: Вероятность dropout - rope: Rotary Positional Embeddings (опционально) + Аргументы: + ---------- + num_heads : int + Сколько attention heads используется в каждом attention слое. + emb_size : int + Размерность входного вектора x. + head_size : int + Размерность каждой attention head; emb_size = num_heads * head_size должно быть True! + feed_forward_layer : nn.Module + Feed-forward слой (например, обычный двухслойный MLP), который применяется после нормы и внимания, и после второй нормы. + max_seq_len : int + Максимальная поддерживаемая длина последовательности (выделяет буфер для causal-маски). + dropout : float, default=0.1 + Dropout после внимания и/или feedforward. """ super().__init__() self._heads = MultiHeadAttention( @@ -86,19 +101,30 @@ class CachedDecoder(nn.Module): cache: list = None, ): """ - Прямой проход с поддержкой кэша. + Прямой проход через Decoder Block с поддержкой KV-кэша. + + В этом методе применяется: + - Causal multi-head attention (masked, не смотрит вперёд) + - Быстрая обработка длинных последовательностей за счёт сохранения и передачи KV-кэша + - LayerNorm перед каждым блоком + - Feed-forward блок и вторая LayerNorm + - Dropout + + Аргументы: + ---------- + x : torch.Tensor + Вход [batch, seq_len, emb_size] + use_cache : bool, по умолчанию True + Включать ли накопление и возврат KV-кэша для autoregressive inferece. + cache : list, опционально + Список предыдущего KV-кеша для attention. + + Возвращает: + ----------- + x_ff_out : torch.Tensor + Результат после attention, модуля и их рез. связей (shape == x) + new_cache : new KV-cache (или None) - Args: - x (Tensor[float]): [batch, seq_len, emb_size] — скрытые состояния - mask (Optional[Tensor]): маска внимания (или causal mask), shape [seq_len, seq_len] - use_cache (bool): использовать кэширование KV - cache (list): кэш self-attention для быстрого авторегрессива - Returns: - output (Tensor[float]): выходные состояния [batch, seq_len, emb_size] - kv_caches (list): обновленный кэш, если use_cache - Пример: - >>> out, new_cache = decoder(x, use_cache=True, cache=old_cache) - >>> out.shape # [batch, seq_len, emb_size] """ norm1_out = self._norm1(x) # Передаём все cache/use_cache дальше в attention diff --git a/llm/tests/core/test_cached_decoder.py b/llm/tests/core/test_cached_decoder.py new file mode 100644 index 0000000..8b05032 --- /dev/null +++ b/llm/tests/core/test_cached_decoder.py @@ -0,0 +1,65 @@ +import torch +import pytest +from llm.core.cached_decoder import CachedDecoder +from llm.core.feed_forward import FeedForward + +@pytest.fixture +def decoder_config(): + return dict( + num_heads=4, + emb_size=32, + head_size=8, + feed_forward_layer=FeedForward(emb_size=32, dropout=0.1, activation="gelu"), + max_seq_len=64, + dropout=0.1 + ) + +def test_cached_decoder_init(decoder_config): + model = CachedDecoder(**decoder_config) + assert model is not None + # Main attention block is usually stored as _heads or _attention (which itself includes _q _k _v) + assert hasattr(model, '_heads') or hasattr(model, '_attention') + assert hasattr(model, '_ff') or hasattr(model, 'feed_forward_layer') + +def test_cached_decoder_forward_shape(decoder_config): + model = CachedDecoder(**decoder_config) + batch, seq_len, emb_size = 3, 10, decoder_config['emb_size'] + x = torch.randn(batch, seq_len, emb_size) + output, cache = model(x, use_cache=True) + assert output.shape == (batch, seq_len, emb_size) + assert cache is not None + +def test_cached_decoder_forward_no_cache(decoder_config): + model = CachedDecoder(**decoder_config) + batch, seq_len, emb_size = 2, 12, decoder_config['emb_size'] + x = torch.randn(batch, seq_len, emb_size) + output, cache = model(x, use_cache=False) + assert output.shape == (batch, seq_len, emb_size) + assert cache is None + +def test_cached_decoder_error_on_long_seq(decoder_config): + model = CachedDecoder(**decoder_config) + batch, seq_len, emb_size = 1, decoder_config['max_seq_len'] + 1, decoder_config['emb_size'] + x = torch.randn(batch, seq_len, emb_size) + with pytest.raises(ValueError): + model(x) + +def test_cached_decoder_backward(decoder_config): + model = CachedDecoder(**decoder_config) + batch, seq_len, emb_size = 2, 7, decoder_config['emb_size'] + x = torch.randn(batch, seq_len, emb_size, requires_grad=True) + output, cache = model(x) + loss = output.sum() + loss.backward() + assert x.grad is not None + +def test_cached_decoder_kv_cache_chain(decoder_config): + model = CachedDecoder(**decoder_config) + batch, seq_len, emb_size = 1, 4, decoder_config['emb_size'] + x = torch.randn(batch, seq_len, emb_size) + # Первый проход — кэша нет + _, cache = model(x, use_cache=True) + # Второй проход — передаём кэш, добавляем еще токен: + next_x = torch.randn(batch, 1, emb_size) + _, cache2 = model(next_x, use_cache=True, cache=cache) + assert cache2 is not None