docs(core): add docstrings and unit tests for CachedDecoder module

- docs: Add detailed docstrings for CachedDecoder class and its methods (__init__, forward); explain autoregressive caching, architecture, math, usage, and links to GPT-2/LLM references
- test: Add comprehensive unit tests for CachedDecoder (initialization, forward with and without cache, cache chaining, output shape, error on long input, backward pass)
- These changes improve code clarity, reliability, and testing for decoder blocks with KV cache.
This commit is contained in:
Sergey Penkovsky
2025-10-16 12:30:53 +03:00
parent ba3b04cec2
commit 923aa51e2a
2 changed files with 139 additions and 48 deletions

View File

@@ -9,33 +9,46 @@ from .rope import RoPE
class CachedDecoder(nn.Module): class CachedDecoder(nn.Module):
""" """
Универсальный декодерный блок для современных LLM (GPT, LLaMA, др.), поддерживает кэширование key-value для эффективной генерации. CachedDecoder — Transformer-декодер с key/value-кэшированием (реализация накладывающегося masked multi-head attention).
Научная идея: Назначение:
Автопагрессивная авторегрессия в трансформерах требует быстрого доступа к ранее вычисленным self-attention ключам/значениям — этот класс позволяет прозрачно кэшировать такие состояния для быстрой инференс-генерации. -----------
Позволяет быстро и эффективно реализовывать autoregressive генерацию текста в стиле GPT-2/3/4:
- На шаге генерации используются только нужные токены, “прошлые” key/value значения не пересчитываются, а подаются из кэша.
- Позволяет значительно ускорять inferece (особенно на длинных последовательностях).
- Вдохновлено реализациями в HuggingFace transformers, GPT-2/3 и других LLM.
Алгоритм: Архитектурные особенности:
- Input -> LayerNorm -> Многоголовое внимание с кэшем (может быть RoPE) --------------------------
- Суммируем residual - Использует классическую multi-head attention (с causal mask — запрещает видеть “будущее”).
- LayerNorm -> FeedForward (любой, например SwiGLU) -> Residual - Предусматривает передачу и накопление KV-cache для каждого слоя (hidden state attention).
- Возвращается кортеж (output, kvcache) - Поддерживает передачу внимания через стек attention-блоков.
- Применяется layernorm и feed-forward block (GELU).
Args: Параметры конструктора:
feed_forward_layer (nn.Module): FeedForward или SwiGLU слой -----------------------
num_heads (int): Количество голов внимания num_heads : int — число attention heads
emb_size (int): Размерность эмбеддингов emb_size : int — embedding размерность
head_size (int): Размерность головы внимания head_size : intразмер каждой attention head (обычно emb_size // num_heads)
max_seq_len (int): Максимальная длина feed_forward_layer : nn.Module — feedforward блок (mLP), может быть любым PyTorch-слоем
norm_layer (тип nn.Module): Normalization слой (LayerNorm или RMSNorm) max_seq_len : int — максимально допустимая длина последовательности
dropout (float): Dropout dropout : float — dropout на attention/ffn
rope (RoPE|None): Экземпляр RoPE (для LLaMA)
Пример использования:
---------------------
>>> from llm.core.feed_forward import FeedForward
>>> ff_block = FeedForward(emb_size=256, dropout=0.1, activation=\"gelu\")
>>> decoder = CachedDecoder(num_heads=4, emb_size=256, head_size=64, feed_forward_layer=ff_block, max_seq_len=2048, dropout=0.1)
>>> x = torch.randn(2, 100, 256)
>>> y, kv_cache = decoder(x, use_cache=True, cache=None)
>>> print(y.shape) # torch.Size([2, 100, 256])
Подробнее:
----------
- GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
- HuggingFace cache mechanics: https://huggingface.co/docs/transformers/main/en/model_doc/gpt2
- Объяснения autoregressive cache: https://jalammar.github.io/illustrated-gpt2/
Пример (GPT2 style):
>>> decoder = CachedDecoder(
... feed_forward_layer=FeedForward(...),
... norm_layer=nn.LayerNorm,
... num_heads=4, emb_size=256, head_size=64, max_seq_len=128)
>>> out, cache = decoder(x, use_cache=True)
""" """
def __init__( def __init__(
@@ -50,20 +63,22 @@ class CachedDecoder(nn.Module):
rope: RoPE = None, rope: RoPE = None,
): ):
""" """
Инициализация декодера с кэшированием. Конструктор CachedDecoder.
Поведение аналогично блоку TransformerDecoderLayer, Аргументы:
но с гибкой возможностью подмены любых подкомпонент (активация, norm, позиции). ----------
num_heads : int
Args: Сколько attention heads используется в каждом attention слое.
feed_forward_layer: Слой feed-forward (должен быть экземпляром, а не классом) emb_size : int
num_heads: Количество голов внимания Размерность входного вектора x.
emb_size: Размерность эмбеддингов head_size : int
head_size: Размерность каждой головы Размерность каждой attention head; emb_size = num_heads * head_size должно быть True!
max_seq_len: Максимальная длина последовательности feed_forward_layer : nn.Module
norm_layer: Класс нормализации (по умолчанию LayerNorm) Feed-forward слой (например, обычный двухслойный MLP), который применяется после нормы и внимания, и после второй нормы.
dropout: Вероятность dropout max_seq_len : int
rope: Rotary Positional Embeddings (опционально) Максимальная поддерживаемая длина последовательности (выделяет буфер для causal-маски).
dropout : float, default=0.1
Dropout после внимания и/или feedforward.
""" """
super().__init__() super().__init__()
self._heads = MultiHeadAttention( self._heads = MultiHeadAttention(
@@ -86,19 +101,30 @@ class CachedDecoder(nn.Module):
cache: list = None, cache: list = None,
): ):
""" """
Прямой проход с поддержкой кэша. Прямой проход через Decoder Block с поддержкой KV-кэша.
В этом методе применяется:
- Causal multi-head attention (masked, не смотрит вперёд)
- Быстрая обработка длинных последовательностей за счёт сохранения и передачи KV-кэша
- LayerNorm перед каждым блоком
- Feed-forward блок и вторая LayerNorm
- Dropout
Аргументы:
----------
x : torch.Tensor
Вход [batch, seq_len, emb_size]
use_cache : bool, по умолчанию True
Включать ли накопление и возврат KV-кэша для autoregressive inferece.
cache : list, опционально
Список предыдущего KV-кеша для attention.
Возвращает:
-----------
x_ff_out : torch.Tensor
Результат после attention, модуля и их рез. связей (shape == x)
new_cache : new KV-cache (или None)
Args:
x (Tensor[float]): [batch, seq_len, emb_size] — скрытые состояния
mask (Optional[Tensor]): маска внимания (или causal mask), shape [seq_len, seq_len]
use_cache (bool): использовать кэширование KV
cache (list): кэш self-attention для быстрого авторегрессива
Returns:
output (Tensor[float]): выходные состояния [batch, seq_len, emb_size]
kv_caches (list): обновленный кэш, если use_cache
Пример:
>>> out, new_cache = decoder(x, use_cache=True, cache=old_cache)
>>> out.shape # [batch, seq_len, emb_size]
""" """
norm1_out = self._norm1(x) norm1_out = self._norm1(x)
# Передаём все cache/use_cache дальше в attention # Передаём все cache/use_cache дальше в attention

View File

@@ -0,0 +1,65 @@
import torch
import pytest
from llm.core.cached_decoder import CachedDecoder
from llm.core.feed_forward import FeedForward
@pytest.fixture
def decoder_config():
return dict(
num_heads=4,
emb_size=32,
head_size=8,
feed_forward_layer=FeedForward(emb_size=32, dropout=0.1, activation="gelu"),
max_seq_len=64,
dropout=0.1
)
def test_cached_decoder_init(decoder_config):
model = CachedDecoder(**decoder_config)
assert model is not None
# Main attention block is usually stored as _heads or _attention (which itself includes _q _k _v)
assert hasattr(model, '_heads') or hasattr(model, '_attention')
assert hasattr(model, '_ff') or hasattr(model, 'feed_forward_layer')
def test_cached_decoder_forward_shape(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 3, 10, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
output, cache = model(x, use_cache=True)
assert output.shape == (batch, seq_len, emb_size)
assert cache is not None
def test_cached_decoder_forward_no_cache(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 12, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
output, cache = model(x, use_cache=False)
assert output.shape == (batch, seq_len, emb_size)
assert cache is None
def test_cached_decoder_error_on_long_seq(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 1, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
with pytest.raises(ValueError):
model(x)
def test_cached_decoder_backward(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 7, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
output, cache = model(x)
loss = output.sum()
loss.backward()
assert x.grad is not None
def test_cached_decoder_kv_cache_chain(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 1, 4, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
# Первый проход — кэша нет
_, cache = model(x, use_cache=True)
# Второй проход — передаём кэш, добавляем еще токен:
next_x = torch.randn(batch, 1, emb_size)
_, cache2 = model(next_x, use_cache=True, cache=cache)
assert cache2 is not None