{
"cells": [
{
"cell_type": "markdown",
"id": "efbc675e",
"metadata": {},
"source": [
"# Llama\n",
"\n",
"![](https://ucarecdn.com/05af6071-73b2-4067-9a39-632fcb2f24e9/)\n",
"\n",
"\n",
"Llama 1 вышла в феврале 2023 года. Это уже подальше, чем GPT-2. И в ее архитектуре появилось уже больше серьезных изменений:\n",
"\n",
"- Нормализация RMSNorm (вместе с pre-norm).\n",
"- Функция активации SwiGLU.\n",
"- Новый способ кодирования позиций — Rotary Positional Embeddings."
]
},
{
"cell_type": "markdown",
"id": "2cedc663",
"metadata": {},
"source": [
"# RMSNorm\n",
"\n",
"![2\\_rmsnorm.png](https://ucarecdn.com/2975a217-27ff-4d26-b4a1-cc48a8de1e45/)\n",
"\n",
"В Llama используется более быстрая и эффективная нормализация — **RMSNorm (Root Mean Square Normalization)**.\n",
"И, также как в GPT-2, используется *pre-norm* нормализация, то есть слои нормализации располагаются **перед блоками внимания и FNN**.\n",
"\n",
"RMSNorm отличается от обычной нормализации только одним: в нём исключен этап центрирования (вычитание среднего) и используется только масштабирование по RMS.\n",
"Это сокращает вычислительные затраты (на 764%) без существенной потери качества.\n",
"На картинке показана разница в распределении после применения RMSNorm и LayerNorm к исходным данным — RMSNorm не разбросан вокруг нуля.\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/cbfbb78e-e2b0-40e2-ba56-73e5114d54f6/\" width=\"350\" alt=\"RMSNorm vs LayerNorm\">\n",
"</p>\n",
"\n",
"## Этапы вычисления RMSNorm\n",
"\n",
"1. **Вычисление среднеквадратичного значения:**\n",
"\n",
" $$\\text{RMS}(\\mathbf{x}) = \\sqrt{\\frac{1}{d} \\sum_{j=1}^{d} x_j^2}$$\n",
"\n",
"2. **Нормализация входящего вектора:**\n",
"\n",
" $$\\hat{x}_i = \\frac{x_i}{\\text{RMS}(\\mathbf{x})}$$\n",
"\n",
"3. **Применение масштабирования:**\n",
"\n",
" $$y_i = w_i \\cdot \\hat{x}_i$$\n",
"\n",
"---\n",
"\n",
"**Где:**\n",
"\n",
"* $x_i$ — *i*-й элемент входящего вектора.\n",
"* $w_i$ — *i*-й элемент обучаемого вектора весов.\n",
" Использование весов позволяет модели адаптивно регулировать амплитуду признаков.\n",
" Без них нормализация была бы слишком «жёсткой» и могла бы ограничить качество модели.\n",
"* $d$ — размерность входящего вектора.\n",
"* $\\varepsilon$ — малая константа (например, 1e-6), предотвращает деление на ноль.\n",
"\n",
"---\n",
"\n",
"Так как на вход подаётся тензор, то в векторной форме RMSNorm вычисляется так:\n",
"\n",
"$$\n",
"RMSNorm(x) = w ⊙ \\frac{x}{\\sqrt{mean(x^2) + ϵ}}\n",
"$$\n",
"\n",
"**Где:**\n",
"\n",
"* $x$ — входящий тензор размера `batch_size × ...`\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "873704be",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"class RMSNorm(nn.Module):\n",
" def __init__(self, dim: int, eps: float = 1e-6):\n",
" super().__init__()\n",
" self._eps = eps\n",
" self._w = nn.Parameter(torch.ones(dim))\n",
" \n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
" rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5\n",
" norm_x = x / rms\n",
" return self._w * norm_x"
]
},
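{
"cell_type": "markdown",
"id": "a1f00001",
"metadata": {},
"source": [
"A quick sanity check (the sizes below are arbitrary): with the weight vector at its initial value of 1, every normalized vector should have an RMS of roughly 1, while the mean of the output is *not* forced to zero, since there is no centering step."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f00002",
"metadata": {},
"outputs": [],
"source": [
"# sanity check: output RMS is ~1, output mean is not forced to 0\n",
"torch.manual_seed(0)\n",
"norm = RMSNorm(dim=8)\n",
"x = torch.randn(2, 3, 8) * 5 + 2  # deliberately scaled and shifted input\n",
"y = norm(x)\n",
"\n",
"print('output RMS: ', y.pow(2).mean(-1).sqrt())  # ~1 everywhere\n",
"print('output mean:', y.mean(-1))  # not ~0: RMSNorm does not center"
]
},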
{
"cell_type": "markdown",
"id": "09dd9625",
"metadata": {},
"source": [
"# SwiGLU\n",
"\n",
"![3\\_swiglu.png](https://ucarecdn.com/120dea77-2bf2-455f-9b54-c35c4efddc9e/)\n",
"\n",
"В **Llama** ввели новую функцию активации — **SwiGLU (Swish-Gated Linear Unit)** — это гибридная функция активации, которая представляет собой комбинацию трёх линейных преобразований и функции активации **SiLU (Sigmoid Linear Unit)**, она же *Swish* в терминологии Google.\n",
"\n",
"Формула SwiGLU выглядит так:\n",
"\n",
"$$\n",
"\\text{SwiGLU}(x) = \\text{down}(\\text{SiLU}(\\text{gate}(x)) \\otimes \\text{up}(x))\n",
"$$\n",
"\n",
"где:\n",
"\n",
"* $x$ — входящий тензор.\n",
"* $\\text{gate}(x)$ — линейный слой для гейтового механизма. Преобразует вход `x` размерностью `emb_size` в промежуточное представление размерности `4 * emb_size`.\n",
"* $\\text{up}(x)$ — линейный слой для увеличения размерности. Также преобразует `x` в размерность `4 * emb_size`.\n",
"* $\\text{SiLU}(x) = x \\cdot \\sigma(x)$ — функция активации, где $\\sigma$ — сигмоида.\n",
"* $\\otimes$ — поэлементное умножение.\n",
"* $\\text{down}(x)$ — линейный слой для уменьшения промежуточного представления до исходного размера (`emb_size`).\n",
"\n",
"> **Гейтинг** (от слова *gate* — «врата») — это механизм, который позволяет сети динамически фильтровать, какая информация должна проходить дальше.\n",
"> При гейтинге создаются как бы два независимых потока:\n",
">\n",
"> * один предназначен для прямой передачи информации (*up-down*),\n",
"> * другой — для контроля передаваемой информации (*gate*).\n",
">\n",
"> Это позволяет сети учить более сложные паттерны.\n",
"> Например, гейт может научиться:\n",
"> «если признак A активен, то пропусти признак B»,\n",
"> что невозможно с простой функцией активации между линейными слоями.\n",
">\n",
"> Также гейтинг помогает с затуханием градиентов: вместо простого обнуления (как в ReLU), гейт может тонко модулировать силу сигнала.\n",
"\n",
"SwiGLU более сложная (дорогая), чем ReLU/GELU, так как требует больше вычислений (три линейных преобразования вместо двух).\n",
"Но при этом показывает лучшее качество по сравнению с ReLU и GELU.\n",
"\n",
"График **SiLU** похож на **GELU**, но более гладкий:\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/6683e0c8-96b7-4389-826a-a73708b4a835/\" width=\"500\" alt=\"SiLU vs GELU\">\n",
"</p>\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0484cf77",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"class SiLU(nn.Module):\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
" return torch.sigmoid(x) * x"
]
},
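{
"cell_type": "markdown",
"id": "a1f00003",
"metadata": {},
"source": [
"As a quick check, the hand-written SiLU can be compared against PyTorch's built-in `torch.nn.functional.silu`, which computes the same function:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f00004",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# compare the hand-written SiLU with PyTorch's built-in implementation\n",
"x = torch.linspace(-5, 5, steps=11)\n",
"ours = SiLU()(x)\n",
"reference = F.silu(x)\n",
"print(torch.allclose(ours, reference, atol=1e-6))  # True\n",
"print(ours)"
]
},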
{
"cell_type": "markdown",
"id": "0b64da5d",
"metadata": {},
"source": [
"## SwiGLU"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74ca39ba",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"class SwiGLU(nn.Module):\n",
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
" super().__init__()\n",
"\n",
" self._gate = nn.Linear(emb_size, 4 * emb_size)\n",
" self._up = nn.Linear(emb_size, 4 * emb_size)\n",
" self._down = nn.Linear(4 * emb_size, emb_size)\n",
" self._actvation = SiLU()\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n",
" gate_out = self._gate(x) # [batch, seq, 4*emb]\n",
" activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n",
" up_out = self._up(x) # [batch, seq, 4*emb]\n",
" out = up_out * activation_out # поэлементное!\n",
" out = self._down(out) # [batch, seq, emb]\n",
" return self._dropout(out)\n",
"\n",
" "
]
},
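{
"cell_type": "markdown",
"id": "a1f00005",
"metadata": {},
"source": [
"A shape check with small, arbitrary sizes: SwiGLU preserves the embedding dimension and only expands to `4 * emb_size` internally."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f00006",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# shape check with small, arbitrary sizes\n",
"emb_size = 16\n",
"ffn = SwiGLU(emb_size=emb_size)\n",
"x = torch.randn(2, 5, emb_size)  # [batch, seq, emb]\n",
"out = ffn(x)\n",
"print(x.shape, '->', out.shape)  # both [2, 5, 16]"
]
},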
{
"cell_type": "code",
"execution_count": 11,
"id": "fe1274b1",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from math import sqrt\n",
"import torch\n",
"from torch import nn\n",
"from torch import Tensor\n",
"\n",
"class TokenEmbeddings(nn.Module):\n",
" def __init__(self, vocab_size: int, emb_size: int):\n",
" super().__init__()\n",
" self._embedding = nn.Embedding(\n",
" num_embeddings=vocab_size,\n",
" embedding_dim=emb_size\n",
" )\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" return self._embedding(x)\n",
"\n",
" @property\n",
" def num_embeddings(self) -> int:\n",
" return self._embedding.num_embeddings\n",
"\n",
" @property\n",
" def embedding_dim(self) -> int:\n",
" return self._embedding.embedding_dim\n",
"\n",
"\n",
"import torch\n",
"from torch import nn, Tensor\n",
"\n",
"class PositionalEmbeddings(nn.Module):\n",
" def __init__(self, max_seq_len: int, emb_size: int):\n",
" super().__init__()\n",
" self.max_seq_len = max_seq_len\n",
" self.emb_size = emb_size\n",
" self.embedding = nn.Embedding(\n",
" num_embeddings=max_seq_len,\n",
" embedding_dim=emb_size\n",
" )\n",
"\n",
" def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:\n",
" if seq_len < 1 or seq_len > self.max_seq_len:\n",
" raise IndexError(f\"Длина {seq_len} должна быть от 1 до {self.max_seq_len}\")\n",
" if start_pos == 0:\n",
" positions = torch.arange(seq_len, device=self.embedding.weight.device)\n",
" else:\n",
" positions = torch.arange(start=start_pos, end=start_pos + seq_len, device=self.embedding.weight.device)\n",
" return self.embedding(positions)\n",
" \n",
" \n",
"class HeadAttention(nn.Module):\n",
"\n",
" def __init__(self, emb_size: int, head_size: int, max_seq_len: int):\n",
" super().__init__()\n",
" self._emb_size = emb_size\n",
" self._head_size = head_size\n",
" self._max_seq_len = max_seq_len\n",
"\n",
" self._k = nn.Linear(emb_size, head_size)\n",
" self._q = nn.Linear(emb_size, head_size)\n",
" self._v = nn.Linear(emb_size, head_size)\n",
"\n",
" mask = torch.tril(torch.ones(max_seq_len, max_seq_len))\n",
" self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())\n",
"\n",
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n",
" seq_len = x.shape[1]\n",
" if seq_len > self._max_seq_len:\n",
" raise ValueError(f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\")\n",
"\n",
" k = self._k(x) # [B, T, hs]\n",
" q = self._q(x) # [B, T, hs]\n",
" v = self._v(x) # [B, T, hs]\n",
"\n",
" if cache is not None:\n",
" k_cache, v_cache = cache\n",
" k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]\n",
" v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]\n",
" \n",
" scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)\n",
" \n",
" if cache is None:\n",
" scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))\n",
" \n",
" weights = F.softmax(scores, dim=-1)\n",
" x_out = weights @ v # [B, T, hs]\n",
"\n",
" if use_cache is True:\n",
" return (x_out, (k, v))\n",
" else:\n",
" return (x_out, None)\n",
" \n",
"from torch import nn\n",
"import torch\n",
"import math\n",
"\n",
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):\n",
"\n",
" super().__init__()\n",
" self._heads = nn.ModuleList([\n",
" HeadAttention(\n",
" emb_size=emb_size, \n",
" head_size=head_size, \n",
" max_seq_len=max_seq_len\n",
" ) for _ in range(num_heads)\n",
" ])\n",
" self._layer = nn.Linear(head_size * num_heads, emb_size)\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):\n",
"\n",
" attention_results = []\n",
" for i, head in enumerate(self._heads):\n",
" head_cache = cache[i] if cache is not None else None\n",
" result = head(x, use_cache=use_cache, cache=head_cache)\n",
" attention_results.append(result)\n",
" \n",
" outputs, caches = zip(*attention_results)\n",
" attention_outputs = list(outputs)\n",
" kv_caches = list(caches)\n",
" \n",
" concatenated_attention = torch.cat(attention_outputs, dim=-1)\n",
"\n",
" projected_output = self._layer(concatenated_attention)\n",
" \n",
" final_output = self._dropout(projected_output)\n",
" \n",
" if use_cache is True:\n",
" return (final_output, kv_caches)\n",
" else:\n",
" return (final_output, None)\n",
"\n",
"\n",
"class GELU(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)\n",
" \n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" return 0.5 * x * (1 + torch.tanh(\n",
" self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))\n",
" ))\n",
"\n",
"class FeedForward(nn.Module):\n",
"\n",
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
" super().__init__()\n",
" self._layer1 = nn.Linear(emb_size, emb_size * 4)\n",
" self._gelu = GELU()\n",
" self._layer2 = nn.Linear(emb_size * 4, emb_size)\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" input_dtype = x.dtype\n",
" \n",
" if input_dtype != self._layer1.weight.dtype:\n",
" self._layer1 = self._layer1.to(dtype=input_dtype)\n",
" self._layer2 = self._layer2.to(dtype=input_dtype)\n",
" \n",
" x = self._layer1(x)\n",
" x = self._gelu(x)\n",
" x = self._layer2(x)\n",
" return self._dropout(x)\n",
" \n",
"class Decoder(nn.Module):\n",
" def __init__(self, \n",
" num_heads: int,\n",
" emb_size: int,\n",
" head_size: int,\n",
" max_seq_len: int,\n",
" dropout: float = 0.1\n",
" ):\n",
" super().__init__()\n",
" self._heads = MultiHeadAttention(\n",
" num_heads=num_heads, \n",
" emb_size=emb_size, \n",
" head_size=head_size, \n",
" max_seq_len=max_seq_len, \n",
" dropout=dropout\n",
" )\n",
" self._ff = FeedForward(emb_size=emb_size, dropout=dropout)\n",
" self._norm1 = RMSNorm(emb_size)\n",
" self._norm2 = RMSNorm(emb_size)\n",
"\n",
" def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:\n",
" norm1_out = self._norm1(x)\n",
" attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)\n",
" out = attention + x\n",
" \n",
" norm2_out = self._norm2(out)\n",
" ffn_out = self._ff(norm2_out)\n",
"\n",
" if use_cache is True:\n",
" return (ffn_out + out, kv_caches)\n",
" else:\n",
" return (ffn_out + out, None)\n",
"\n",
"\n",
"\n",
"from torch import nn\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"class Llama(nn.Module):\n",
" def __init__(self,\n",
" vocab_size: int,\n",
" max_seq_len: int,\n",
" emb_size: int,\n",
" num_heads: int,\n",
" head_size: int,\n",
" num_layers: int,\n",
" dropout: float = 0.1,\n",
" device: str = 'cpu'\n",
" ):\n",
" super().__init__()\n",
" self._vocab_size = vocab_size\n",
" self._max_seq_len = max_seq_len\n",
" self._emb_size = emb_size\n",
" self._num_heads = num_heads\n",
" self._head_size = head_size\n",
" self._num_layers = num_layers\n",
" self._dropout = dropout\n",
" self._device = device\n",
" \n",
" self.validation_loss = None\n",
"\n",
" # Инициализация слоев\n",
" self._token_embeddings = TokenEmbeddings(\n",
" vocab_size=vocab_size, \n",
" emb_size=emb_size\n",
" )\n",
" self._position_embeddings = PositionalEmbeddings(\n",
" max_seq_len=max_seq_len, \n",
" emb_size=emb_size\n",
" )\n",
" self._dropout = nn.Dropout(dropout)\n",
" self._decoders = nn.ModuleList([Decoder(\n",
" num_heads=num_heads,\n",
" emb_size=emb_size,\n",
" head_size=head_size,\n",
" max_seq_len=max_seq_len,\n",
" dropout=dropout \n",
" ) for _ in range(num_layers)])\n",
" self._norm = RMSNorm(emb_size)\n",
" self._linear = nn.Linear(emb_size, vocab_size)\n",
"\n",
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:\n",
" # Проверка длины последовательности (только при отсутствии кэша)\n",
" if cache is None and x.size(1) > self._max_seq_len:\n",
" raise ValueError(f\"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}\")\n",
" \n",
" \n",
" # Вычисление start_pos из кэша (если кэш передан)\n",
" if cache is not None:\n",
" # При кэше обрабатываем только один токен (последний)\n",
" seq_len = 1\n",
" # Вычисляем start_pos из самого нижнего уровня кэша\n",
" if cache and cache[0] and cache[0][0]:\n",
" key_cache, _ = cache[0][0] # Первый декодер, первая голова\n",
" start_pos = key_cache.size(1) # cache_len\n",
" else:\n",
" start_pos = 0\n",
" else:\n",
" # Без кэша работаем как раньше\n",
" start_pos = 0\n",
" seq_len = x.size(1)\n",
"\n",
" # Эмбеддинги токенов и позиций\n",
" tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]\n",
" pos_out = self._position_embeddings(seq_len, start_pos=start_pos) # [seq_len, emb_size]\n",
" \n",
" # Комбинирование\n",
" out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]\n",
" \n",
" # Стек декодеров с передачей кэша\n",
" new_cache = []\n",
" for i, decoder in enumerate(self._decoders):\n",
" decoder_cache = cache[i] if cache is not None else None\n",
" decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)\n",
"\n",
" # Извлекаем результат из кортежа\n",
" if use_cache:\n",
" out, decoder_new_cache = decoder_result\n",
" new_cache.append(decoder_new_cache)\n",
" else:\n",
" out = decoder_result[0]\n",
"\n",
" out = self._norm(out)\n",
" logits = self._linear(out)\n",
" \n",
" # Возвращаем результат с учетом use_cache\n",
" if use_cache:\n",
" return (logits, new_cache)\n",
" else:\n",
" return (logits, None)\n",
"\n",
" def generate(self,\n",
" x: torch.Tensor, \n",
" max_new_tokens: int, \n",
" do_sample: bool,\n",
" temperature: float = 1.0,\n",
" top_k: int = None,\n",
" top_p: float = None,\n",
" use_cache: bool = True\n",
" ) -> torch.Tensor:\n",
" cache = None\n",
"\n",
" for _ in range(max_new_tokens):\n",
" if use_cache and cache is not None:\n",
" # Используем кэш - передаем только последний токен\n",
" x_input = x[:, -1:] # [batch_size, 1]\n",
" else:\n",
" # Первая итерация или кэш отключен - передаем всю последовательность\n",
" x_input = x\n",
" \n",
" # Прямой проход с кэшем\n",
" logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)\n",
" \n",
" # Обновляем кэш для следующей итерации\n",
" if use_cache:\n",
" cache = new_cache\n",
"\n",
" last_logits = logits[:, -1, :] # [batch_size, vocab_size]\n",
"\n",
" # Масштабируем логиты температурой\n",
" if temperature > 0:\n",
" logits_scaled = last_logits / temperature\n",
" else:\n",
" logits_scaled = last_logits\n",
"\n",
" if do_sample == True and top_k != None:\n",
" _, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)\n",
"\n",
" # # Заменим все НЕ top-k логиты на -inf\n",
" masked_logits = logits_scaled.clone()\n",
" vocab_size = logits_scaled.size(-1)\n",
"\n",
" # создаём маску: 1, если токен НЕ в topk_indices\n",
" mask = torch.ones_like(logits_scaled, dtype=torch.uint8)\n",
" mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы\n",
" masked_logits[mask.byte()] = float('-inf')\n",
"\n",
" logits_scaled = masked_logits\n",
"\n",
" if do_sample == True and top_p != None:\n",
" # 1. Применим softmax, чтобы получить вероятности:\n",
" probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]\n",
" # 2. Отсортируем токены по убыванию вероятностей:\n",
" sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n",
" # 3. Посчитаем кумулятивную сумму вероятностей:\n",
" cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]\n",
" # 4. Определим маску: оставить токены, пока сумма < top_p\n",
" sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]\n",
" # Гарантируем, что хотя бы первый токен останется\n",
" sorted_mask[:, 0] = 1\n",
" # 5. Преобразуем маску обратно в оригинальный порядок:\n",
" # Создаём полную маску из 0\n",
" mask = torch.zeros_like(probs, dtype=torch.uint8)\n",
" # Устанавливаем 1 в местах нужных токенов\n",
" mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)\n",
" # 6. Зануляем логиты токенов вне топ-p:\n",
" logits_scaled[~mask] = float('-inf')\n",
"\n",
" # 4. Применяем Softmax\n",
" probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]\n",
"\n",
"\n",
" if do_sample == True:\n",
" # 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial\n",
" next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]\n",
" else:\n",
" # 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью\n",
" next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]\n",
" \n",
" # 6. Добавляем его к последовательности\n",
" x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]\n",
" return x\n",
"\n",
" def save(self, path):\n",
" torch.save({\n",
" 'model_state_dict': self.state_dict(),\n",
" 'vocab_size': self._vocab_size,\n",
" 'max_seq_len': self._max_seq_len,\n",
" 'emb_size': self._emb_size,\n",
" 'num_heads': self._num_heads,\n",
" 'head_size': self._head_size,\n",
" 'num_layers': self._num_layers\n",
" }, path)\n",
"\n",
" @classmethod\n",
" def load(cls, path, device):\n",
" checkpoint = torch.load(path, map_location=device)\n",
" model = cls(\n",
" vocab_size=checkpoint['vocab_size'],\n",
" max_seq_len=checkpoint['max_seq_len'],\n",
" emb_size=checkpoint['emb_size'],\n",
" num_heads=checkpoint['num_heads'],\n",
" head_size=checkpoint['head_size'],\n",
" num_layers=checkpoint['num_layers']\n",
" )\n",
" model.load_state_dict(checkpoint['model_state_dict'])\n",
" model.to(device)\n",
" return model\n",
"\n",
" @property\n",
" def max_seq_len(self) -> int:\n",
" return self._max_seq_len"
]
},
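{
"cell_type": "markdown",
"id": "a1f00007",
"metadata": {},
"source": [
"The top-p (nucleus) filtering step from `generate` can also be watched in isolation on a toy distribution (the numbers are arbitrary):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f00008",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# top-p filtering in isolation, on a toy batch with one row of logits\n",
"logits = torch.tensor([[2.0, 1.5, 0.5, -1.0, -3.0]])\n",
"probs = F.softmax(logits, dim=-1)\n",
"print('probs:', probs)\n",
"\n",
"top_p = 0.9\n",
"sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n",
"cum_probs = torch.cumsum(sorted_probs, dim=-1)\n",
"sorted_mask = cum_probs <= top_p\n",
"sorted_mask[:, 0] = True  # the most probable token always survives\n",
"mask = torch.zeros_like(probs, dtype=torch.bool)\n",
"mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)\n",
"\n",
"filtered = logits.masked_fill(~mask, float('-inf'))\n",
"print('kept tokens: ', mask)\n",
"print('renormalized:', F.softmax(filtered, dim=-1))"
]
},
{
"cell_type": "markdown",
"id": "a1f00009",
"metadata": {},
"source": [
"Finally, a smoke test of the whole model with a tiny, arbitrary configuration and random weights (so the generated tokens are meaningless): with greedy decoding, generation with and without the KV cache should produce the same tokens, up to floating-point effects, while the cached variant feeds only the newest token through the decoder stack after the first step."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1f0000a",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# smoke test with a tiny, arbitrary configuration and random weights\n",
"torch.manual_seed(0)\n",
"model = Llama(\n",
"    vocab_size=100,\n",
"    max_seq_len=32,\n",
"    emb_size=16,\n",
"    num_heads=2,\n",
"    head_size=8,\n",
"    num_layers=2\n",
")\n",
"model.eval()  # disable dropout so both runs are deterministic\n",
"\n",
"prompt = torch.randint(0, 100, (1, 4))  # [batch=1, seq=4]\n",
"with torch.no_grad():\n",
"    out_cached = model.generate(prompt, max_new_tokens=8, do_sample=False, use_cache=True)\n",
"    out_plain = model.generate(prompt, max_new_tokens=8, do_sample=False, use_cache=False)\n",
"\n",
"print(out_cached)\n",
"print(torch.equal(out_cached, out_plain))  # expected: True"
]
},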
{
"cell_type": "code",
"execution_count": null,
"id": "8efb1396",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}