mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
docs(core): add docstrings and unit tests for MistralDecoder
- docs: expanded docstrings for MistralDecoder class and methods (__init__, forward); explained architecture, key parameters, usage, and links to relevant papers (Mistral, Llama 2) - test: add comprehensive unit tests for MistralDecoder (init, forward, cache handling, output shape, shape errors, backward) - These changes improve explainability, reliability, and test coverage for the decoder module.
This commit is contained in:
@@ -8,6 +8,44 @@ from llm.core.rope import RoPE
|
|||||||
from llm.core.group_query_attention import GroupedQueryAttention
|
from llm.core.group_query_attention import GroupedQueryAttention
|
||||||
|
|
||||||
class MistralDecoder(nn.Module):
|
class MistralDecoder(nn.Module):
|
||||||
|
"""
|
||||||
|
MistralDecoder — стек декодирующих блоков, реализующий архитектуру Mistral-style Transformer.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
Этот класс описывает один или несколько блоков декодера, включающих Grouped Query Attention (GQA),
|
||||||
|
sliding window attention и SwiGLU feed-forward, как реализовано в моделях Mistral и Llama 2.
|
||||||
|
|
||||||
|
Ключевые особенности архитектуры:
|
||||||
|
---------------------------------
|
||||||
|
- Использует GQA: для каждого токена вычисляется attention c раздельным числом Q и KV голов (сильно ускоряет LLM).
|
||||||
|
- Sliding Window Attention: внимание ограничено окном из window_size элементов (ускоряет обработку длинных текстов).
|
||||||
|
- Rotary Positional Embedding (RoPE): позиционная информация интегрируется вращением Q/K.
|
||||||
|
- RMSNorm перед и после внимания и FFN (устойчивое обучение).
|
||||||
|
- SwiGLU в качестве нелинейности вместо стандартного GELU (больше capacity в модели).
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
num_layers : int — сколько блоков-декодеров в стеке
|
||||||
|
параметры GQA: num_q_heads, num_kv_heads, emb_size, head_size, max_seq_len, window_size, rope, dropout
|
||||||
|
- все они идут в каждый слой (блок) декодера
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> decoder = MistralDecoder(
|
||||||
|
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
|
||||||
|
... max_seq_len=4096, window_size=256, rope=rope, dropout=0.1)
|
||||||
|
>>> x = torch.randn(2, 512, 256)
|
||||||
|
>>> out, cache = decoder(x)
|
||||||
|
>>> print(out.shape) # torch.Size([2, 512, 256])
|
||||||
|
|
||||||
|
Подробнее:
|
||||||
|
----------
|
||||||
|
- Mistral: https://arxiv.org/abs/2310.06825
|
||||||
|
- Llama 2: https://arxiv.org/abs/2307.09288
|
||||||
|
- Open LLM обзор: https://huggingface.co/blog/mistral
|
||||||
|
|
||||||
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
num_q_heads: int,
|
num_q_heads: int,
|
||||||
num_kv_heads: int,
|
num_kv_heads: int,
|
||||||
@@ -18,6 +56,35 @@ class MistralDecoder(nn.Module):
|
|||||||
rope: RoPE,
|
rope: RoPE,
|
||||||
dropout: float = 0.1
|
dropout: float = 0.1
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Инициализация стека декодеров MistralDecoder.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
num_layers : int
|
||||||
|
Сколько слоёв (декодеров/GQA-блоков) собрать в стек.
|
||||||
|
num_q_heads : int
|
||||||
|
Количество Query-heads в attention (их больше, экономит память).
|
||||||
|
num_kv_heads : int
|
||||||
|
Количество Key/Value-heads в attention (их меньше для быстрой генерации).
|
||||||
|
emb_size : int
|
||||||
|
Размерность embedding (должна делиться на num_q_heads без остатка).
|
||||||
|
head_size : int
|
||||||
|
Размер одного attention head.
|
||||||
|
max_seq_len : int
|
||||||
|
Максимально обрабатываемая длина последовательности.
|
||||||
|
window_size : int
|
||||||
|
Размер окна для sliding window attention.
|
||||||
|
rope : RoPE
|
||||||
|
Rotary Positional Embedding для Q/K.
|
||||||
|
dropout : float, опционально
|
||||||
|
Dropout на каждом attention/FFN (по умолчанию 0.1).
|
||||||
|
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Собираются num_layers Sequential-блоков из GQA + SwiGLU + RMSNorm.
|
||||||
|
- Все параметры передаются в каждый слой (блок).
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._heads = GroupedQueryAttention(
|
self._heads = GroupedQueryAttention(
|
||||||
num_q_heads=num_q_heads,
|
num_q_heads=num_q_heads,
|
||||||
@@ -34,6 +101,26 @@ class MistralDecoder(nn.Module):
|
|||||||
self._norm2 = RMSNorm(emb_size)
|
self._norm2 = RMSNorm(emb_size)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прямой проход через стек MistralDecoder.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
----------
|
||||||
|
x : torch.Tensor
|
||||||
|
Входные эмбеддинги (обычно shape [batch, seq_len, emb_size]).
|
||||||
|
use_cache : bool, по умолчанию True
|
||||||
|
Включить ли кэширование для ускорения генерации (авторегрессия).
|
||||||
|
cache : list, опционально
|
||||||
|
Предыдущий кеш attention-блоков (или None).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
-----------
|
||||||
|
out : torch.Tensor
|
||||||
|
Тензор после декодирования (shape соответствует x).
|
||||||
|
new_cache : list (или None)
|
||||||
|
Новый кэш attention для дальнейшей генерации (или None, если use_cache=False).
|
||||||
|
|
||||||
|
"""
|
||||||
norm1_out = self._norm1(x)
|
norm1_out = self._norm1(x)
|
||||||
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
|
||||||
out = attention + x
|
out = attention + x
|
||||||
|
|||||||
66
llm/tests/core/test_mistral_decoder.py
Normal file
66
llm/tests/core/test_mistral_decoder.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.core.mistral_decoder import MistralDecoder
|
||||||
|
from llm.core.rope import RoPE
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def decoder_config():
|
||||||
|
# Current MistralDecoder is a single block (not a stack).
|
||||||
|
return dict(
|
||||||
|
num_q_heads=4,
|
||||||
|
num_kv_heads=2,
|
||||||
|
emb_size=32,
|
||||||
|
head_size=8,
|
||||||
|
max_seq_len=128,
|
||||||
|
window_size=16,
|
||||||
|
rope=RoPE(head_size=8, max_seq_len=128),
|
||||||
|
dropout=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_mistral_decoder_init(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
assert model is not None
|
||||||
|
|
||||||
|
def test_mistral_decoder_forward_shapes(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
output, cache = model(x, use_cache=True)
|
||||||
|
assert output.shape == (batch, seq_len, emb_size)
|
||||||
|
assert cache is not None
|
||||||
|
|
||||||
|
def test_mistral_decoder_forward_no_cache(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
output, cache = model(x, use_cache=False)
|
||||||
|
assert output.shape == (batch, seq_len, emb_size)
|
||||||
|
assert cache is None
|
||||||
|
|
||||||
|
def test_mistral_decoder_cache_shapes(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 8, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
# Первый проход — без кэша
|
||||||
|
_, cache = model(x, use_cache=True)
|
||||||
|
# Второй проход — заполняем кэш
|
||||||
|
x_next = torch.randn(batch, 1, emb_size)
|
||||||
|
_, cache2 = model(x_next, use_cache=True, cache=cache)
|
||||||
|
# Можно проверить, что кэш не None и корректной структуры:
|
||||||
|
assert cache2 is not None
|
||||||
|
|
||||||
|
def test_mistral_decoder_shape_error(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
model(x)
|
||||||
|
|
||||||
|
def test_mistral_decoder_backward(decoder_config):
|
||||||
|
model = MistralDecoder(**decoder_config)
|
||||||
|
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
|
||||||
|
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
|
||||||
|
output, _ = model(x, use_cache=False)
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
Reference in New Issue
Block a user