docs(models): update and expand docstrings for Mistral and its methods

- docs: add comprehensive docstrings for the Mistral class (in Russian) and its methods (forward, generate)
- docs: explain model architecture (GQA, Sliding Window Attention, SwiGLU, RMSNorm, RoPE), arguments, constraints, generation modes, usage examples, and references (Mistral, nucleus sampling)
- strictly documentation improvements, no logic/API changes

This commit makes Mistral model documentation clear and user-friendly for LLM engineering and inference.
This commit is contained in:
Sergey Penkovsky
2025-10-16 17:01:57 +03:00
parent aec3c8adb6
commit 38c271ca3c

View File

@@ -11,6 +11,51 @@ from llm.core.mistral_decoder import MistralDecoder
class Mistral(BaseModel): class Mistral(BaseModel):
"""
Mistral — автогерессивная языковая LLM-архитектура (2023, Mistral AI) для быстрого и качественного моделирования текста.
Назначение:
-----------
- Модель построена на базе decoder-only Transformer с важными оптимизациями: GQA (Grouped Query Attention), RoPE, SwiGLU, RMSNorm, sliding window attention.
- Поддерживает autoregressive generation (step-by-step текст), обучение и inference на длинных последовательностях.
- Используется в современных open-source LLM: Mistral-7B, Mixtral-8x7B и др.
Архитектурные особенности:
--------------------------
- Токеновые эмбеддинги (TokenEmbeddings) и позиционное кодирование через RoPE (rotary position embedding).
- Stack из num_layers декодеров с Grouped Query Attention (раздельное число query/key heads для оптимизации памяти).
- Sliding Window Attention Mask — позволяет ускорять обработку длинных текстов, ограничивая область внимания для каждого токена (как в оригинальном Mistral).
- SwiGLU FeedForward-блоки и RMSNorm.
- Dropout (регуляризация).
- Кэширование attention (KV cache) для быстрой генерации токенов по одному.
Аргументы конструктора:
-----------------------
config (dict): параметры модели (см. документацию Mistral):
vocab_size: int — размер словаря токенов
embed_dim: int — размерность эмбеддингов
num_q_heads: int — количество query-голов (обычно больше num_kv_heads)
num_kv_heads: int — количество key/value attention-голов
num_layers: int — число слоёв-декодеров
max_position_embeddings: int — максимальная длина последовательности
window_size: int — размер sliding window attention
dropout: float — dropout (обычно очень мал или 0)
...
Пример использования:
---------------------
>>> model = Mistral({...})
>>> tokens = torch.tensor([[100, 56, 8]])
>>> logits = model(tokens)
>>> generated = model.generate(tokens, max_new_tokens=16, do_sample=True, top_k=50)
References:
-----------
- "Mistral: Fast and Efficient Dense and Mixture of Experts Transformer Models" (2023): https://arxiv.org/abs/2310.06825
- LLaMA v2 & Grouped-Query Attention: https://arxiv.org/abs/2307.09288
- Оригинальное обсуждение архитектуры: https://huggingface.co/blog/mistral
"""
def __init__(self, config): def __init__(self, config):
super().__init__(config) super().__init__(config)
@@ -43,6 +88,25 @@ class Mistral(BaseModel):
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"]) self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple: def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
"""
Прямой проход (forward) через всю модель Mistral: возвращает логиты для токенов и (опционально) кэш attention для ускорения autoregressive генерации.
Аргументы:
x (torch.Tensor): Входной тензор с токенами (shape [batch_size, seq_len]), где значения — индексы токенов.
use_cache (bool, по умолчанию True): Возвращать ли новый KV attention-кэш для последующей генерации.
cache (list or None): Предыдущий кэш attention (или None для полного прохода без накопления кэша).
Возвращает:
logits (torch.Tensor): Тензор логитов shape [batch_size, seq_len, vocab_size] — вероятностное распределение по словарю для каждого токена.
new_cache (list or None): Новый кэш KV attention-слоев (или None, если use_cache=False).
Исключения:
ValueError: Если длина последовательности превышает максимальную (max_seq_len), когда не используется кэш.
Пример:
>>> logits, cache = model.forward(input_ids, use_cache=True)
>>> probabilities = torch.softmax(logits, dim=-1)
"""
# Проверка длины последовательности (только при отсутствии кэша) # Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len: if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}") raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
@@ -85,6 +149,51 @@ class Mistral(BaseModel):
top_p: float = None, top_p: float = None,
use_cache: bool = True use_cache: bool = True
) -> torch.Tensor: ) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
Аргументы:
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
Возвращает:
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее max_seq_len модели.
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры:
>>> # Жадная генерация
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
>>> # Сэмплирование с температурой
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
>>> # Top-k sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
>>> # Top-p (nucleus) sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
>>> # Температура + top-k
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
Примечания:
- Одновременно использовать top_k и top_p нельзя.
- Параметры temperature, top_k, top_p работают только при do_sample=True.
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
- Mistral: https://arxiv.org/abs/2310.06825
"""
cache = None cache = None
for _ in range(max_new_tokens): for _ in range(max_new_tokens):