diff --git a/llm/src/llm/models/mistral/mistral.py b/llm/src/llm/models/mistral/mistral.py index aae3985..1e56eea 100644 --- a/llm/src/llm/models/mistral/mistral.py +++ b/llm/src/llm/models/mistral/mistral.py @@ -11,6 +11,51 @@ from llm.core.mistral_decoder import MistralDecoder class Mistral(BaseModel): + """ + Mistral — автогерессивная языковая LLM-архитектура (2023, Mistral AI) для быстрого и качественного моделирования текста. + + Назначение: + ----------- + - Модель построена на базе decoder-only Transformer с важными оптимизациями: GQA (Grouped Query Attention), RoPE, SwiGLU, RMSNorm, sliding window attention. + - Поддерживает autoregressive generation (step-by-step текст), обучение и inference на длинных последовательностях. + - Используется в современных open-source LLM: Mistral-7B, Mixtral-8x7B и др. + + Архитектурные особенности: + -------------------------- + - Токеновые эмбеддинги (TokenEmbeddings) и позиционное кодирование через RoPE (rotary position embedding). + - Stack из num_layers декодеров с Grouped Query Attention (раздельное число query/key heads для оптимизации памяти). + - Sliding Window Attention Mask — позволяет ускорять обработку длинных текстов, ограничивая область внимания для каждого токена (как в оригинальном Mistral). + - SwiGLU FeedForward-блоки и RMSNorm. + - Dropout (регуляризация). + - Кэширование attention (KV cache) для быстрой генерации токенов по одному. + + Аргументы конструктора: + ----------------------- + config (dict): параметры модели (см. документацию Mistral): + vocab_size: int — размер словаря токенов + embed_dim: int — размерность эмбеддингов + num_q_heads: int — количество query-голов (обычно больше num_kv_heads) + num_kv_heads: int — количество key/value attention-голов + num_layers: int — число слоёв-декодеров + max_position_embeddings: int — максимальная длина последовательности + window_size: int — размер sliding window attention + dropout: float — dropout (обычно очень мал или 0) + ... + + Пример использования: + --------------------- + >>> model = Mistral({...}) + >>> tokens = torch.tensor([[100, 56, 8]]) + >>> logits = model(tokens) + >>> generated = model.generate(tokens, max_new_tokens=16, do_sample=True, top_k=50) + + References: + ----------- + - "Mistral: Fast and Efficient Dense and Mixture of Experts Transformer Models" (2023): https://arxiv.org/abs/2310.06825 + - LLaMA v2 & Grouped-Query Attention: https://arxiv.org/abs/2307.09288 + - Оригинальное обсуждение архитектуры: https://huggingface.co/blog/mistral + + """ def __init__(self, config): super().__init__(config) @@ -43,6 +88,25 @@ class Mistral(BaseModel): self._linear = nn.Linear(config["embed_dim"], config["vocab_size"]) def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple: + """ + Прямой проход (forward) через всю модель Mistral: возвращает логиты для токенов и (опционально) кэш attention для ускорения autoregressive генерации. + + Аргументы: + x (torch.Tensor): Входной тензор с токенами (shape [batch_size, seq_len]), где значения — индексы токенов. + use_cache (bool, по умолчанию True): Возвращать ли новый KV attention-кэш для последующей генерации. + cache (list or None): Предыдущий кэш attention (или None для полного прохода без накопления кэша). + + Возвращает: + logits (torch.Tensor): Тензор логитов shape [batch_size, seq_len, vocab_size] — вероятностное распределение по словарю для каждого токена. + new_cache (list or None): Новый кэш KV attention-слоев (или None, если use_cache=False). + + Исключения: + ValueError: Если длина последовательности превышает максимальную (max_seq_len), когда не используется кэш. + + Пример: + >>> logits, cache = model.forward(input_ids, use_cache=True) + >>> probabilities = torch.softmax(logits, dim=-1) + """ # Проверка длины последовательности (только при отсутствии кэша) if cache is None and x.size(1) > self._max_seq_len: raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}") @@ -85,6 +149,51 @@ class Mistral(BaseModel): top_p: float = None, use_cache: bool = True ) -> torch.Tensor: + """ + Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling + и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах). + + Аргументы: + x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len]. + max_new_tokens (int): Максимальное количество новых токенов для генерации. + do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax). + temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие. + top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов. + top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p. + use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима. + + Возвращает: + torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens]. + + Исключения: + ValueError: Если x длиннее max_seq_len модели. + ValueError: Если temperature ≤ 0. + ValueError: Если одновременно заданы top_k и top_p. + ValueError: Если top_k ≤ 0. + ValueError: Если top_p не в диапазоне (0, 1]. + + Примеры: + >>> # Жадная генерация + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False) + >>> # Сэмплирование с температурой + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8) + >>> # Top-k sampling + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50) + >>> # Top-p (nucleus) sampling + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92) + >>> # Температура + top-k + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100) + + Примечания: + - Одновременно использовать top_k и top_p нельзя. + - Параметры temperature, top_k, top_p работают только при do_sample=True. + - Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed. + - Метод всегда возвращает только индексы токенов; для получения логитов используйте forward. + + Ссылки: + - Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751 + - Mistral: https://arxiv.org/abs/2310.06825 + """ cache = None for _ in range(max_new_tokens):