mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
docs(core): add docstrings and unit tests for MistralDecoder
- docs: expanded docstrings for MistralDecoder class and methods (__init__, forward); explained architecture, key parameters, usage, and links to relevant papers (Mistral, Llama 2) - test: add comprehensive unit tests for MistralDecoder (init, forward, cache handling, output shape, shape errors, backward) - These changes improve explainability, reliability, and test coverage for the decoder module.
This commit is contained in:
@@ -8,6 +8,44 @@ from llm.core.rope import RoPE
|
||||
from llm.core.group_query_attention import GroupedQueryAttention
|
||||
|
||||
class MistralDecoder(nn.Module):
|
||||
"""
|
||||
MistralDecoder — стек декодирующих блоков, реализующий архитектуру Mistral-style Transformer.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
Этот класс описывает один или несколько блоков декодера, включающих Grouped Query Attention (GQA),
|
||||
sliding window attention и SwiGLU feed-forward, как реализовано в моделях Mistral и Llama 2.
|
||||
|
||||
Ключевые особенности архитектуры:
|
||||
---------------------------------
|
||||
- Использует GQA: для каждого токена вычисляется attention c раздельным числом Q и KV голов (сильно ускоряет LLM).
|
||||
- Sliding Window Attention: внимание ограничено окном из window_size элементов (ускоряет обработку длинных текстов).
|
||||
- Rotary Positional Embedding (RoPE): позиционная информация интегрируется вращением Q/K.
|
||||
- RMSNorm перед и после внимания и FFN (устойчивое обучение).
|
||||
- SwiGLU в качестве нелинейности вместо стандартного GELU (больше capacity в модели).
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
num_layers : int — сколько блоков-декодеров в стеке
|
||||
параметры GQA: num_q_heads, num_kv_heads, emb_size, head_size, max_seq_len, window_size, rope, dropout
|
||||
- все они идут в каждый слой (блок) декодера
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> decoder = MistralDecoder(
|
||||
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
|
||||
... max_seq_len=4096, window_size=256, rope=rope, dropout=0.1)
|
||||
>>> x = torch.randn(2, 512, 256)
|
||||
>>> out, cache = decoder(x)
|
||||
>>> print(out.shape) # torch.Size([2, 512, 256])
|
||||
|
||||
Подробнее:
|
||||
----------
|
||||
- Mistral: https://arxiv.org/abs/2310.06825
|
||||
- Llama 2: https://arxiv.org/abs/2307.09288
|
||||
- Open LLM обзор: https://huggingface.co/blog/mistral
|
||||
|
||||
"""
|
||||
def __init__(self,
|
||||
num_q_heads: int,
|
||||
num_kv_heads: int,
|
||||
@@ -18,6 +56,35 @@ class MistralDecoder(nn.Module):
|
||||
rope: RoPE,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
"""
|
||||
Инициализация стека декодеров MistralDecoder.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
num_layers : int
|
||||
Сколько слоёв (декодеров/GQA-блоков) собрать в стек.
|
||||
num_q_heads : int
|
||||
Количество Query-heads в attention (их больше, экономит память).
|
||||
num_kv_heads : int
|
||||
Количество Key/Value-heads в attention (их меньше для быстрой генерации).
|
||||
emb_size : int
|
||||
Размерность embedding (должна делиться на num_q_heads без остатка).
|
||||
head_size : int
|
||||
Размер одного attention head.
|
||||
max_seq_len : int
|
||||
Максимально обрабатываемая длина последовательности.
|
||||
window_size : int
|
||||
Размер окна для sliding window attention.
|
||||
rope : RoPE
|
||||
Rotary Positional Embedding для Q/K.
|
||||
dropout : float, опционально
|
||||
Dropout на каждом attention/FFN (по умолчанию 0.1).
|
||||
|
||||
Внутри:
|
||||
-------
|
||||
- Собираются num_layers Sequential-блоков из GQA + SwiGLU + RMSNorm.
|
||||
- Все параметры передаются в каждый слой (блок).
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = GroupedQueryAttention(
|
||||
num_q_heads=num_q_heads,
|
||||
@@ -34,6 +101,26 @@ class MistralDecoder(nn.Module):
|
||||
self._norm2 = RMSNorm(emb_size)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через стек MistralDecoder.
|
||||
|
||||
Аргументы:
|
||||
----------
|
||||
x : torch.Tensor
|
||||
Входные эмбеддинги (обычно shape [batch, seq_len, emb_size]).
|
||||
use_cache : bool, по умолчанию True
|
||||
Включить ли кэширование для ускорения генерации (авторегрессия).
|
||||
cache : list, опционально
|
||||
Предыдущий кеш attention-блоков (или None).
|
||||
|
||||
Возвращает:
|
||||
-----------
|
||||
out : torch.Tensor
|
||||
Тензор после декодирования (shape соответствует x).
|
||||
new_cache : list (или None)
|
||||
Новый кэш attention для дальнейшей генерации (или None, если use_cache=False).
|
||||
|
||||
"""
|
||||
norm1_out = self._norm1(x)
|
||||
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||
out = attention + x
|
||||
|
||||
66
llm/tests/core/test_mistral_decoder.py
Normal file
66
llm/tests/core/test_mistral_decoder.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.core.mistral_decoder import MistralDecoder
|
||||
from llm.core.rope import RoPE
|
||||
|
||||
@pytest.fixture
|
||||
def decoder_config():
|
||||
# Current MistralDecoder is a single block (not a stack).
|
||||
return dict(
|
||||
num_q_heads=4,
|
||||
num_kv_heads=2,
|
||||
emb_size=32,
|
||||
head_size=8,
|
||||
max_seq_len=128,
|
||||
window_size=16,
|
||||
rope=RoPE(head_size=8, max_seq_len=128),
|
||||
dropout=0.0
|
||||
)
|
||||
|
||||
def test_mistral_decoder_init(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
assert model is not None
|
||||
|
||||
def test_mistral_decoder_forward_shapes(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
output, cache = model(x, use_cache=True)
|
||||
assert output.shape == (batch, seq_len, emb_size)
|
||||
assert cache is not None
|
||||
|
||||
def test_mistral_decoder_forward_no_cache(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
output, cache = model(x, use_cache=False)
|
||||
assert output.shape == (batch, seq_len, emb_size)
|
||||
assert cache is None
|
||||
|
||||
def test_mistral_decoder_cache_shapes(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 8, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
# Первый проход — без кэша
|
||||
_, cache = model(x, use_cache=True)
|
||||
# Второй проход — заполняем кэш
|
||||
x_next = torch.randn(batch, 1, emb_size)
|
||||
_, cache2 = model(x_next, use_cache=True, cache=cache)
|
||||
# Можно проверить, что кэш не None и корректной структуры:
|
||||
assert cache2 is not None
|
||||
|
||||
def test_mistral_decoder_shape_error(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size)
|
||||
with pytest.raises(ValueError):
|
||||
model(x)
|
||||
|
||||
def test_mistral_decoder_backward(decoder_config):
|
||||
model = MistralDecoder(**decoder_config)
|
||||
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
|
||||
output, _ = model(x, use_cache=False)
|
||||
loss = output.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
Reference in New Issue
Block a user