mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
docs(models): update and expand docstrings for Mistral and its methods
- docs: add comprehensive docstrings for the Mistral class (in Russian) and its methods (forward, generate) - docs: explain model architecture (GQA, Sliding Window Attention, SwiGLU, RMSNorm, RoPE), arguments, constraints, generation modes, usage examples, and references (Mistral, nucleus sampling) - strictly documentation improvements, no logic/API changes This commit makes Mistral model documentation clear and user-friendly for LLM engineering and inference.
This commit is contained in:
@@ -11,6 +11,51 @@ from llm.core.mistral_decoder import MistralDecoder
|
|||||||
|
|
||||||
|
|
||||||
class Mistral(BaseModel):
|
class Mistral(BaseModel):
|
||||||
|
"""
|
||||||
|
Mistral — автогерессивная языковая LLM-архитектура (2023, Mistral AI) для быстрого и качественного моделирования текста.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Модель построена на базе decoder-only Transformer с важными оптимизациями: GQA (Grouped Query Attention), RoPE, SwiGLU, RMSNorm, sliding window attention.
|
||||||
|
- Поддерживает autoregressive generation (step-by-step текст), обучение и inference на длинных последовательностях.
|
||||||
|
- Используется в современных open-source LLM: Mistral-7B, Mixtral-8x7B и др.
|
||||||
|
|
||||||
|
Архитектурные особенности:
|
||||||
|
--------------------------
|
||||||
|
- Токеновые эмбеддинги (TokenEmbeddings) и позиционное кодирование через RoPE (rotary position embedding).
|
||||||
|
- Stack из num_layers декодеров с Grouped Query Attention (раздельное число query/key heads для оптимизации памяти).
|
||||||
|
- Sliding Window Attention Mask — позволяет ускорять обработку длинных текстов, ограничивая область внимания для каждого токена (как в оригинальном Mistral).
|
||||||
|
- SwiGLU FeedForward-блоки и RMSNorm.
|
||||||
|
- Dropout (регуляризация).
|
||||||
|
- Кэширование attention (KV cache) для быстрой генерации токенов по одному.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
config (dict): параметры модели (см. документацию Mistral):
|
||||||
|
vocab_size: int — размер словаря токенов
|
||||||
|
embed_dim: int — размерность эмбеддингов
|
||||||
|
num_q_heads: int — количество query-голов (обычно больше num_kv_heads)
|
||||||
|
num_kv_heads: int — количество key/value attention-голов
|
||||||
|
num_layers: int — число слоёв-декодеров
|
||||||
|
max_position_embeddings: int — максимальная длина последовательности
|
||||||
|
window_size: int — размер sliding window attention
|
||||||
|
dropout: float — dropout (обычно очень мал или 0)
|
||||||
|
...
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> model = Mistral({...})
|
||||||
|
>>> tokens = torch.tensor([[100, 56, 8]])
|
||||||
|
>>> logits = model(tokens)
|
||||||
|
>>> generated = model.generate(tokens, max_new_tokens=16, do_sample=True, top_k=50)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- "Mistral: Fast and Efficient Dense and Mixture of Experts Transformer Models" (2023): https://arxiv.org/abs/2310.06825
|
||||||
|
- LLaMA v2 & Grouped-Query Attention: https://arxiv.org/abs/2307.09288
|
||||||
|
- Оригинальное обсуждение архитектуры: https://huggingface.co/blog/mistral
|
||||||
|
|
||||||
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
@@ -43,6 +88,25 @@ class Mistral(BaseModel):
|
|||||||
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
|
||||||
|
"""
|
||||||
|
Прямой проход (forward) через всю модель Mistral: возвращает логиты для токенов и (опционально) кэш attention для ускорения autoregressive генерации.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
x (torch.Tensor): Входной тензор с токенами (shape [batch_size, seq_len]), где значения — индексы токенов.
|
||||||
|
use_cache (bool, по умолчанию True): Возвращать ли новый KV attention-кэш для последующей генерации.
|
||||||
|
cache (list or None): Предыдущий кэш attention (или None для полного прохода без накопления кэша).
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
logits (torch.Tensor): Тензор логитов shape [batch_size, seq_len, vocab_size] — вероятностное распределение по словарю для каждого токена.
|
||||||
|
new_cache (list or None): Новый кэш KV attention-слоев (или None, если use_cache=False).
|
||||||
|
|
||||||
|
Исключения:
|
||||||
|
ValueError: Если длина последовательности превышает максимальную (max_seq_len), когда не используется кэш.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> logits, cache = model.forward(input_ids, use_cache=True)
|
||||||
|
>>> probabilities = torch.softmax(logits, dim=-1)
|
||||||
|
"""
|
||||||
# Проверка длины последовательности (только при отсутствии кэша)
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
if cache is None and x.size(1) > self._max_seq_len:
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
|
||||||
@@ -85,6 +149,51 @@ class Mistral(BaseModel):
|
|||||||
top_p: float = None,
|
top_p: float = None,
|
||||||
use_cache: bool = True
|
use_cache: bool = True
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
||||||
|
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||||
|
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||||
|
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
|
||||||
|
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
|
||||||
|
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
|
||||||
|
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
|
||||||
|
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
|
||||||
|
|
||||||
|
Исключения:
|
||||||
|
ValueError: Если x длиннее max_seq_len модели.
|
||||||
|
ValueError: Если temperature ≤ 0.
|
||||||
|
ValueError: Если одновременно заданы top_k и top_p.
|
||||||
|
ValueError: Если top_k ≤ 0.
|
||||||
|
ValueError: Если top_p не в диапазоне (0, 1].
|
||||||
|
|
||||||
|
Примеры:
|
||||||
|
>>> # Жадная генерация
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||||
|
>>> # Сэмплирование с температурой
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
|
||||||
|
>>> # Top-k sampling
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||||
|
>>> # Top-p (nucleus) sampling
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||||
|
>>> # Температура + top-k
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
- Одновременно использовать top_k и top_p нельзя.
|
||||||
|
- Параметры temperature, top_k, top_p работают только при do_sample=True.
|
||||||
|
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
|
||||||
|
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
|
||||||
|
|
||||||
|
Ссылки:
|
||||||
|
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
|
||||||
|
- Mistral: https://arxiv.org/abs/2310.06825
|
||||||
|
"""
|
||||||
cache = None
|
cache = None
|
||||||
|
|
||||||
for _ in range(max_new_tokens):
|
for _ in range(max_new_tokens):
|
||||||
|
|||||||
Reference in New Issue
Block a user