mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
docs(models): update and expand docstrings for LLaMA and generate method
- docs: add full, detailed Russian-language docstring for LLaMA.generate (sampling, top-k/top-p, examples, all parameter constraints and references) - docs: bring LLaMA class header in line with modern LLM doc practices (motivation, architecture, references) - no changes to logic, API, or tests This makes the LLaMA model documentation fully transparent for all generation and inference modes.
This commit is contained in:
@@ -12,24 +12,62 @@ from llm.core.cached_decoder import CachedDecoder
|
|||||||
|
|
||||||
class Llama(BaseModel):
|
class Llama(BaseModel):
|
||||||
"""
|
"""
|
||||||
LLaMA (Large Language Model Meta AI) — высокоэффективная масштабируемая языковая модель, разработанная Meta AI Research.
|
LLaMA — автогерессивная большая языковая модель (Large Language Model from Meta, 2023).
|
||||||
|
|
||||||
Ключевые идеи:
|
Назначение:
|
||||||
- Rotary Positional Encoding (RoPE) вместо стандартных позиционных эмбеддингов
|
-----------
|
||||||
- RMSNorm (Root Mean Square LayerNorm) вместо LayerNorm
|
- Модель реализует архитектуру decoder-only Transformer с современными "индустриальными" трюками (RMSNorm, SwiGLU, RoPE, GQA).
|
||||||
- SwiGLU как нелинейность вместо ReLU/GELU (больше экспрессивности)
|
- Предназначена для генерации текста, чат-ботов, zero-/few-shot вывода, fine-tune в стиле RLHF, transfer learning и исследований в LLM.
|
||||||
- Глубокая оптимизация inference (большая экономия памяти и FLOPs)
|
|
||||||
Подробнее: https://arxiv.org/abs/2302.13971
|
Архитектурные особенности:
|
||||||
|
--------------------------
|
||||||
|
- Токеновые эмбеддинги и позиционное кодирование с помощью Rotary Position Embedding (RoPE, https://arxiv.org/abs/2104.09864).
|
||||||
|
- Stack из num_layers современных декодеров с Grouped Query Attention (GQA: num_q_heads > num_kv_heads) для эффективной генерации.
|
||||||
|
- FeedForward блоки с SwiGLU (см. https://arxiv.org/abs/2002.05202).
|
||||||
|
- Нормализация RMSNorm перед каждым sub-layer (вот почему "Pre-RMSNorm").
|
||||||
|
- Кэширование attention (KV cache) для быстрой autoregressive генерации.
|
||||||
|
- Нет bias в Linear слоях, нет Dropout внутри attention.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
config: dict с требуемыми ключами:
|
||||||
|
vocab_size: int — размер словаря токенов
|
||||||
|
embed_dim: int — размерность эмбеддингов
|
||||||
|
num_q_heads: int — количество query-голов в attention (обычно больше num_kv_heads)
|
||||||
|
num_kv_heads: int — количество key/value-голов
|
||||||
|
num_layers: int — число слоёв-декодеров
|
||||||
|
max_position_embeddings: int — максимальная длина последовательности
|
||||||
|
window_size: int (optional) — размер sliding window для attention
|
||||||
|
dropout: float (обычно 0.0 или очень мал)
|
||||||
|
...
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> llama = LLaMA({...})
|
||||||
|
>>> tokens = torch.tensor([[100, 56, 8]])
|
||||||
|
>>> logits = llama(tokens)
|
||||||
|
>>> out = llama.generate(tokens, max_new_tokens=10, do_sample=True, top_k=50)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- "LLaMA: Open and Efficient Foundation Language Models" (Touvron et al., 2023): https://arxiv.org/abs/2302.13971
|
||||||
|
- "Grouped-Query Attention": https://arxiv.org/abs/2307.09288
|
||||||
|
- "RoFormer: Enhanced Transformer with Rotary Position Embedding": https://arxiv.org/abs/2104.09864
|
||||||
|
- Discussion of efficient LLMs: https://huggingface.co/blog/mistral
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): параметры архитектуры (vocab_size, embed_dim, num_heads, num_layers, max_position_embeddings, dropout)
|
|
||||||
Пример:
|
|
||||||
>>> model = Llama({...})
|
|
||||||
>>> logits, cache = model(input_ids, use_cache=True)
|
|
||||||
>>> out = model.generate(input_ids, max_new_tokens=20)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
|
"""
|
||||||
|
Инициализация LLaMA.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config (dict): Параметры архитектуры, см. docstring класса.
|
||||||
|
Внутри:
|
||||||
|
-------
|
||||||
|
- Создаёт Embedding-слой, Rotary Position Embeddings (RoPE), стек слоёв с GQA, RMSNorm, SwiGLU.
|
||||||
|
- Финальный слой нормализации и проекции на vocabulary.
|
||||||
|
"""
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|
||||||
# Инициализация слоев
|
# Инициализация слоев
|
||||||
@@ -68,17 +106,16 @@ class Llama(BaseModel):
|
|||||||
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
|
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
|
||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Прямой проход через LLaMA (inference/train): авторегрессионное предсказание токенов.
|
Прямой проход: возвращает logits (и возможно обновлённый cache) по входным токенам.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (Tensor[int]): входные токены [batch, seq_len]
|
x (torch.Tensor): [batch, seq_len] — индексы токенов, shape [batch, seq_len]
|
||||||
use_cache (bool): использовать ли кэш (ускоряет генерацию)
|
use_cache (bool): использовать механизм KV cache (ускоряет autoregressive generation)
|
||||||
cache (list|None): ключи и значения attention для autoregressive режима
|
cache (list or None): предыдущий кэш, если нужен
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logits (Tensor): [batch, seq_len, vocab_size]
|
logits: torch.Tensor [batch, seq_len, vocab_size]
|
||||||
new_cache (list|None): новый кэш attention (если use_cache)
|
new_cache: новый кэш attention (или None)
|
||||||
Пример:
|
|
||||||
>>> logits, cache = model.forward(x, use_cache=True)
|
|
||||||
"""
|
"""
|
||||||
# Проверка длины последовательности (только при отсутствии кэша)
|
# Проверка длины последовательности (только при отсутствии кэша)
|
||||||
if cache is None and x.size(1) > self._max_seq_len:
|
if cache is None and x.size(1) > self._max_seq_len:
|
||||||
|
|||||||
Reference in New Issue
Block a user