docs(core): add docstrings and unit tests for MistralDecoder

- docs: expanded docstrings for MistralDecoder class and methods (__init__, forward); explained architecture, key parameters, usage, and links to relevant papers (Mistral, Llama 2)
- test: add comprehensive unit tests for MistralDecoder (init, forward, cache handling, output shape, shape errors, backward)
- These changes improve explainability, reliability, and test coverage for the decoder module.
This commit is contained in:
Sergey Penkovsky
2025-10-15 18:07:11 +03:00
parent e6ca8dee6f
commit ba3b04cec2
2 changed files with 153 additions and 0 deletions

View File

@@ -8,6 +8,44 @@ from llm.core.rope import RoPE
from llm.core.group_query_attention import GroupedQueryAttention
class MistralDecoder(nn.Module):
"""
MistralDecoder — стек декодирующих блоков, реализующий архитектуру Mistral-style Transformer.
Назначение:
-----------
Этот класс описывает один или несколько блоков декодера, включающих Grouped Query Attention (GQA),
sliding window attention и SwiGLU feed-forward, как реализовано в моделях Mistral и Llama 2.
Ключевые особенности архитектуры:
---------------------------------
- Использует GQA: для каждого токена вычисляется attention c раздельным числом Q и KV голов (сильно ускоряет LLM).
- Sliding Window Attention: внимание ограничено окном из window_size элементов (ускоряет обработку длинных текстов).
- Rotary Positional Embedding (RoPE): позиционная информация интегрируется вращением Q/K.
- RMSNorm перед и после внимания и FFN (устойчивое обучение).
- SwiGLU в качестве нелинейности вместо стандартного GELU (больше capacity в модели).
Аргументы конструктора:
-----------------------
num_layers : int — сколько блоков-декодеров в стеке
параметры GQA: num_q_heads, num_kv_heads, emb_size, head_size, max_seq_len, window_size, rope, dropout
- все они идут в каждый слой (блок) декодера
Пример использования:
---------------------
>>> decoder = MistralDecoder(
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
... max_seq_len=4096, window_size=256, rope=rope, dropout=0.1)
>>> x = torch.randn(2, 512, 256)
>>> out, cache = decoder(x)
>>> print(out.shape) # torch.Size([2, 512, 256])
Подробнее:
----------
- Mistral: https://arxiv.org/abs/2310.06825
- Llama 2: https://arxiv.org/abs/2307.09288
- Open LLM обзор: https://huggingface.co/blog/mistral
"""
def __init__(self,
num_q_heads: int,
num_kv_heads: int,
@@ -18,6 +56,35 @@ class MistralDecoder(nn.Module):
rope: RoPE,
dropout: float = 0.1
):
"""
Инициализация стека декодеров MistralDecoder.
Аргументы:
----------
num_layers : int
Сколько слоёв (декодеров/GQA-блоков) собрать в стек.
num_q_heads : int
Количество Query-heads в attention (их больше, экономит память).
num_kv_heads : int
Количество Key/Value-heads в attention (их меньше для быстрой генерации).
emb_size : int
Размерность embedding (должна делиться на num_q_heads без остатка).
head_size : int
Размер одного attention head.
max_seq_len : int
Максимально обрабатываемая длина последовательности.
window_size : int
Размер окна для sliding window attention.
rope : RoPE
Rotary Positional Embedding для Q/K.
dropout : float, опционально
Dropout на каждом attention/FFN (по умолчанию 0.1).
Внутри:
-------
- Собираются num_layers Sequential-блоков из GQA + SwiGLU + RMSNorm.
- Все параметры передаются в каждый слой (блок).
"""
super().__init__()
self._heads = GroupedQueryAttention(
num_q_heads=num_q_heads,
@@ -34,6 +101,26 @@ class MistralDecoder(nn.Module):
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
"""
Прямой проход через стек MistralDecoder.
Аргументы:
----------
x : torch.Tensor
Входные эмбеддинги (обычно shape [batch, seq_len, emb_size]).
use_cache : bool, по умолчанию True
Включить ли кэширование для ускорения генерации (авторегрессия).
cache : list, опционально
Предыдущий кеш attention-блоков (или None).
Возвращает:
-----------
out : torch.Tensor
Тензор после декодирования (shape соответствует x).
new_cache : list (или None)
Новый кэш attention для дальнейшей генерации (или None, если use_cache=False).
"""
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x

View File

@@ -0,0 +1,66 @@
import torch
import pytest
from llm.core.mistral_decoder import MistralDecoder
from llm.core.rope import RoPE
@pytest.fixture
def decoder_config():
# Current MistralDecoder is a single block (not a stack).
return dict(
num_q_heads=4,
num_kv_heads=2,
emb_size=32,
head_size=8,
max_seq_len=128,
window_size=16,
rope=RoPE(head_size=8, max_seq_len=128),
dropout=0.0
)
def test_mistral_decoder_init(decoder_config):
model = MistralDecoder(**decoder_config)
assert model is not None
def test_mistral_decoder_forward_shapes(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
output, cache = model(x, use_cache=True)
assert output.shape == (batch, seq_len, emb_size)
assert cache is not None
def test_mistral_decoder_forward_no_cache(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
output, cache = model(x, use_cache=False)
assert output.shape == (batch, seq_len, emb_size)
assert cache is None
def test_mistral_decoder_cache_shapes(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 8, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
# Первый проход — без кэша
_, cache = model(x, use_cache=True)
# Второй проход — заполняем кэш
x_next = torch.randn(batch, 1, emb_size)
_, cache2 = model(x_next, use_cache=True, cache=cache)
# Можно проверить, что кэш не None и корректной структуры:
assert cache2 is not None
def test_mistral_decoder_shape_error(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
with pytest.raises(ValueError):
model(x)
def test_mistral_decoder_backward(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
output, _ = model(x, use_cache=False)
loss = output.sum()
loss.backward()
assert x.grad is not None