fix: typo in activation attribute for SwiGLU (rename _actvation to _activation) and minor index update

This commit is contained in:
Sergey Penkovsky
2025-10-05 23:01:58 +03:00
parent e5b5a97811
commit b6f56a2640

notebooks/llama.ipynb

@@ -0,0 +1,665 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "efbc675e",
"metadata": {},
"source": [
"# Llama\n",
"\n",
"![](https://ucarecdn.com/05af6071-73b2-4067-9a39-632fcb2f24e9/)\n",
"\n",
"\n",
"Llama 1 вышла в феврале 2023 года. Это уже подальше, чем GPT-2. И в ее архитектуре появилось уже больше серьезных изменений:\n",
"\n",
"- Нормализация RMSNorm (вместе с pre-norm).\n",
"- Функция активации SwiGLU.\n",
"- Новый способ кодирования позиций — Rotary Positional Embeddings."
]
},
{
"cell_type": "markdown",
"id": "2cedc663",
"metadata": {},
"source": [
"# RMSNorm\n",
"\n",
"![2\\_rmsnorm.png](https://ucarecdn.com/2975a217-27ff-4d26-b4a1-cc48a8de1e45/)\n",
"\n",
"В Llama используется более быстрая и эффективная нормализация — **RMSNorm (Root Mean Square Normalization)**.\n",
"И, также как в GPT-2, используется *pre-norm* нормализация, то есть слои нормализации располагаются **перед блоками внимания и FNN**.\n",
"\n",
"RMSNorm отличается от обычной нормализации только одним: в нём исключен этап центрирования (вычитание среднего) и используется только масштабирование по RMS.\n",
"Это сокращает вычислительные затраты (на 764%) без существенной потери качества.\n",
"На картинке показана разница в распределении после применения RMSNorm и LayerNorm к исходным данным — RMSNorm не разбросан вокруг нуля.\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/cbfbb78e-e2b0-40e2-ba56-73e5114d54f6/\" width=\"350\" alt=\"RMSNorm vs LayerNorm\">\n",
"</p>\n",
"\n",
"## Этапы вычисления RMSNorm\n",
"\n",
"1. **Вычисление среднеквадратичного значения:**\n",
"\n",
" $$\\text{RMS}(\\mathbf{x}) = \\sqrt{\\frac{1}{d} \\sum_{j=1}^{d} x_j^2}$$\n",
"\n",
"2. **Нормализация входящего вектора:**\n",
"\n",
" $$\\hat{x}_i = \\frac{x_i}{\\text{RMS}(\\mathbf{x})}$$\n",
"\n",
"3. **Применение масштабирования:**\n",
"\n",
" $$y_i = w_i \\cdot \\hat{x}_i$$\n",
"\n",
"---\n",
"\n",
"**Где:**\n",
"\n",
"* $x_i$ — *i*-й элемент входящего вектора.\n",
"* $w_i$ — *i*-й элемент обучаемого вектора весов.\n",
" Использование весов позволяет модели адаптивно регулировать амплитуду признаков.\n",
" Без них нормализация была бы слишком «жёсткой» и могла бы ограничить качество модели.\n",
"* $d$ — размерность входящего вектора.\n",
"* $\\varepsilon$ — малая константа (например, 1e-6), предотвращает деление на ноль.\n",
"\n",
"---\n",
"\n",
"Так как на вход подаётся тензор, то в векторной форме RMSNorm вычисляется так:\n",
"\n",
"$$\n",
"RMSNorm(x) = w ⊙ \\frac{x}{\\sqrt{mean(x^2) + ϵ}}\n",
"$$\n",
"\n",
"**Где:**\n",
"\n",
"* $x$ — входящий тензор размера `batch_size × ...`\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "873704be",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"class RMSNorm(nn.Module):\n",
" def __init__(self, dim: int, eps: float = 1e-6):\n",
" super().__init__()\n",
" self._eps = eps\n",
" self._w = nn.Parameter(torch.ones(dim))\n",
" \n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
" rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5\n",
" norm_x = x / rms\n",
" return self._w * norm_x"
]
},
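{
"cell_type": "markdown",
"id": "a1b2c3d4",
"metadata": {},
"source": [
"A quick sanity check (a minimal sketch, not part of the model itself): with the initial weights $w = 1$, the RMS of every normalized vector along the last dimension should be close to 1, while the mean, unlike with LayerNorm, is not forced to zero."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b2c3d4e5",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check for RMSNorm: unit RMS, but no centering\n",
"torch.manual_seed(0)\n",
"x = torch.randn(2, 5, 8) * 3.0 + 1.0  # deliberately shifted away from zero\n",
"y = RMSNorm(dim=8)(x)\n",
"print(y.pow(2).mean(-1).sqrt())  # all values close to 1.0\n",
"print(y.mean(-1))                # generally nonzero: no mean subtraction"
]
},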
{
"cell_type": "markdown",
"id": "09dd9625",
"metadata": {},
"source": [
"# SwiGLU\n",
"\n",
"![3\\_swiglu.png](https://ucarecdn.com/120dea77-2bf2-455f-9b54-c35c4efddc9e/)\n",
"\n",
"В **Llama** ввели новую функцию активации — **SwiGLU (Swish-Gated Linear Unit)** — это гибридная функция активации, которая представляет собой комбинацию трёх линейных преобразований и функции активации **SiLU (Sigmoid Linear Unit)**, она же *Swish* в терминологии Google.\n",
"\n",
"Формула SwiGLU выглядит так:\n",
"\n",
"$$\n",
"\\text{SwiGLU}(x) = \\text{down}(\\text{SiLU}(\\text{gate}(x)) \\otimes \\text{up}(x))\n",
"$$\n",
"\n",
"где:\n",
"\n",
"* $x$ — входящий тензор.\n",
"* $\\text{gate}(x)$ — линейный слой для гейтового механизма. Преобразует вход `x` размерностью `emb_size` в промежуточное представление размерности `4 * emb_size`.\n",
"* $\\text{up}(x)$ — линейный слой для увеличения размерности. Также преобразует `x` в размерность `4 * emb_size`.\n",
"* $\\text{SiLU}(x) = x \\cdot \\sigma(x)$ — функция активации, где $\\sigma$ — сигмоида.\n",
"* $\\otimes$ — поэлементное умножение.\n",
"* $\\text{down}(x)$ — линейный слой для уменьшения промежуточного представления до исходного размера (`emb_size`).\n",
"\n",
"> **Гейтинг** (от слова *gate* — «врата») — это механизм, который позволяет сети динамически фильтровать, какая информация должна проходить дальше.\n",
"> При гейтинге создаются как бы два независимых потока:\n",
">\n",
"> * один предназначен для прямой передачи информации (*up-down*),\n",
"> * другой — для контроля передаваемой информации (*gate*).\n",
">\n",
"> Это позволяет сети учить более сложные паттерны.\n",
"> Например, гейт может научиться:\n",
"> «если признак A активен, то пропусти признак B»,\n",
"> что невозможно с простой функцией активации между линейными слоями.\n",
">\n",
"> Также гейтинг помогает с затуханием градиентов: вместо простого обнуления (как в ReLU), гейт может тонко модулировать силу сигнала.\n",
"\n",
"SwiGLU более сложная (дорогая), чем ReLU/GELU, так как требует больше вычислений (три линейных преобразования вместо двух).\n",
"Но при этом показывает лучшее качество по сравнению с ReLU и GELU.\n",
"\n",
"График **SiLU** похож на **GELU**, но более гладкий:\n",
"\n",
"<p align=\"center\">\n",
" <img src=\"https://ucarecdn.com/6683e0c8-96b7-4389-826a-a73708b4a835/\" width=\"500\" alt=\"SiLU vs GELU\">\n",
"</p>\n"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0484cf77",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"\n",
"class SiLU(nn.Module):\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n",
" return torch.sigmoid(x) * x"
]
},
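{
"cell_type": "markdown",
"id": "c3d4e5f6",
"metadata": {},
"source": [
"As a quick check (a sketch, assuming the `SiLU` class above), the hand-written activation should match PyTorch's built-in `F.silu`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d4e5f6a7",
"metadata": {},
"outputs": [],
"source": [
"# The custom SiLU should agree with torch.nn.functional.silu\n",
"x = torch.linspace(-5, 5, steps=11)\n",
"print(torch.allclose(SiLU()(x), F.silu(x)))  # True"
]
},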
{
"cell_type": "markdown",
"id": "0b64da5d",
"metadata": {},
"source": [
"## SwiGLU"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74ca39ba",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"class SwiGLU(nn.Module):\n",
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
" super().__init__()\n",
"\n",
" self._gate = nn.Linear(emb_size, 4 * emb_size)\n",
" self._up = nn.Linear(emb_size, 4 * emb_size)\n",
" self._down = nn.Linear(4 * emb_size, emb_size)\n",
" self._actvation = SiLU()\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n",
" gate_out = self._gate(x) # [batch, seq, 4*emb]\n",
" activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n",
" up_out = self._up(x) # [batch, seq, 4*emb]\n",
" out = up_out * activation_out # поэлементное!\n",
" out = self._down(out) # [batch, seq, emb]\n",
" return self._dropout(out)\n",
"\n",
" "
]
},
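{
"cell_type": "markdown",
"id": "e5f6a7b8",
"metadata": {},
"source": [
"A small usage sketch (hypothetical sizes, not from the text above): SwiGLU preserves the input shape, and with the `4 * emb_size` expansion used here it carries three weight matrices (`gate`, `up`, `down`) instead of the two in a classic FFN, which is where the extra cost comes from."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6a7b8c9",
"metadata": {},
"outputs": [],
"source": [
"emb_size = 64\n",
"block = SwiGLU(emb_size=emb_size)\n",
"x = torch.randn(2, 10, emb_size)  # [batch, seq, emb]\n",
"print(block(x).shape)  # torch.Size([2, 10, 64]): the shape is preserved\n",
"\n",
"# Three projection matrices of roughly 4*emb*emb each (plus biases)\n",
"print(sum(p.numel() for p in block.parameters()))"
]
},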
{
"cell_type": "code",
"execution_count": 11,
"id": "fe1274b1",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch import nn\n",
"import torch.nn.functional as F\n",
"from math import sqrt\n",
"import torch\n",
"from torch import nn\n",
"from torch import Tensor\n",
"\n",
"class TokenEmbeddings(nn.Module):\n",
" def __init__(self, vocab_size: int, emb_size: int):\n",
" super().__init__()\n",
" self._embedding = nn.Embedding(\n",
" num_embeddings=vocab_size,\n",
" embedding_dim=emb_size\n",
" )\n",
"\n",
" def forward(self, x: Tensor) -> Tensor:\n",
" return self._embedding(x)\n",
"\n",
" @property\n",
" def num_embeddings(self) -> int:\n",
" return self._embedding.num_embeddings\n",
"\n",
" @property\n",
" def embedding_dim(self) -> int:\n",
" return self._embedding.embedding_dim\n",
"\n",
"\n",
"import torch\n",
"from torch import nn, Tensor\n",
"\n",
"class PositionalEmbeddings(nn.Module):\n",
" def __init__(self, max_seq_len: int, emb_size: int):\n",
" super().__init__()\n",
" self.max_seq_len = max_seq_len\n",
" self.emb_size = emb_size\n",
" self.embedding = nn.Embedding(\n",
" num_embeddings=max_seq_len,\n",
" embedding_dim=emb_size\n",
" )\n",
"\n",
" def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:\n",
" if seq_len < 1 or seq_len > self.max_seq_len:\n",
" raise IndexError(f\"Длина {seq_len} должна быть от 1 до {self.max_seq_len}\")\n",
" if start_pos == 0:\n",
" positions = torch.arange(seq_len, device=self.embedding.weight.device)\n",
" else:\n",
" positions = torch.arange(start=start_pos, end=start_pos + seq_len, device=self.embedding.weight.device)\n",
" return self.embedding(positions)\n",
" \n",
" \n",
"class HeadAttention(nn.Module):\n",
"\n",
" def __init__(self, emb_size: int, head_size: int, max_seq_len: int):\n",
" super().__init__()\n",
" self._emb_size = emb_size\n",
" self._head_size = head_size\n",
" self._max_seq_len = max_seq_len\n",
"\n",
" self._k = nn.Linear(emb_size, head_size)\n",
" self._q = nn.Linear(emb_size, head_size)\n",
" self._v = nn.Linear(emb_size, head_size)\n",
"\n",
" mask = torch.tril(torch.ones(max_seq_len, max_seq_len))\n",
" self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())\n",
"\n",
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n",
" seq_len = x.shape[1]\n",
" if seq_len > self._max_seq_len:\n",
" raise ValueError(f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\")\n",
"\n",
" k = self._k(x) # [B, T, hs]\n",
" q = self._q(x) # [B, T, hs]\n",
" v = self._v(x) # [B, T, hs]\n",
"\n",
" if cache is not None:\n",
" k_cache, v_cache = cache\n",
" k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]\n",
" v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]\n",
" \n",
" scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)\n",
" \n",
" if cache is None:\n",
" scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))\n",
" \n",
" weights = F.softmax(scores, dim=-1)\n",
" x_out = weights @ v # [B, T, hs]\n",
"\n",
" if use_cache is True:\n",
" return (x_out, (k, v))\n",
" else:\n",
" return (x_out, None)\n",
" \n",
"from torch import nn\n",
"import torch\n",
"import math\n",
"\n",
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):\n",
"\n",
" super().__init__()\n",
" self._heads = nn.ModuleList([\n",
" HeadAttention(\n",
" emb_size=emb_size, \n",
" head_size=head_size, \n",
" max_seq_len=max_seq_len\n",
" ) for _ in range(num_heads)\n",
" ])\n",
" self._layer = nn.Linear(head_size * num_heads, emb_size)\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):\n",
"\n",
" attention_results = []\n",
" for i, head in enumerate(self._heads):\n",
" head_cache = cache[i] if cache is not None else None\n",
" result = head(x, use_cache=use_cache, cache=head_cache)\n",
" attention_results.append(result)\n",
" \n",
" outputs, caches = zip(*attention_results)\n",
" attention_outputs = list(outputs)\n",
" kv_caches = list(caches)\n",
" \n",
" concatenated_attention = torch.cat(attention_outputs, dim=-1)\n",
"\n",
" projected_output = self._layer(concatenated_attention)\n",
" \n",
" final_output = self._dropout(projected_output)\n",
" \n",
" if use_cache is True:\n",
" return (final_output, kv_caches)\n",
" else:\n",
" return (final_output, None)\n",
"\n",
"\n",
"class GELU(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)\n",
" \n",
" def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
" return 0.5 * x * (1 + torch.tanh(\n",
" self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))\n",
" ))\n",
"\n",
"class FeedForward(nn.Module):\n",
"\n",
" def __init__(self, emb_size: int, dropout: float = 0.1):\n",
" super().__init__()\n",
" self._layer1 = nn.Linear(emb_size, emb_size * 4)\n",
" self._gelu = GELU()\n",
" self._layer2 = nn.Linear(emb_size * 4, emb_size)\n",
" self._dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" input_dtype = x.dtype\n",
" \n",
" if input_dtype != self._layer1.weight.dtype:\n",
" self._layer1 = self._layer1.to(dtype=input_dtype)\n",
" self._layer2 = self._layer2.to(dtype=input_dtype)\n",
" \n",
" x = self._layer1(x)\n",
" x = self._gelu(x)\n",
" x = self._layer2(x)\n",
" return self._dropout(x)\n",
" \n",
"class Decoder(nn.Module):\n",
" def __init__(self, \n",
" num_heads: int,\n",
" emb_size: int,\n",
" head_size: int,\n",
" max_seq_len: int,\n",
" dropout: float = 0.1\n",
" ):\n",
" super().__init__()\n",
" self._heads = MultiHeadAttention(\n",
" num_heads=num_heads, \n",
" emb_size=emb_size, \n",
" head_size=head_size, \n",
" max_seq_len=max_seq_len, \n",
" dropout=dropout\n",
" )\n",
" self._ff = FeedForward(emb_size=emb_size, dropout=dropout)\n",
" self._norm1 = RMSNorm(emb_size)\n",
" self._norm2 = RMSNorm(emb_size)\n",
"\n",
" def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:\n",
" norm1_out = self._norm1(x)\n",
" attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)\n",
" out = attention + x\n",
" \n",
" norm2_out = self._norm2(out)\n",
" ffn_out = self._ff(norm2_out)\n",
"\n",
" if use_cache is True:\n",
" return (ffn_out + out, kv_caches)\n",
" else:\n",
" return (ffn_out + out, None)\n",
"\n",
"\n",
"\n",
"from torch import nn\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"class Llama(nn.Module):\n",
" def __init__(self,\n",
" vocab_size: int,\n",
" max_seq_len: int,\n",
" emb_size: int,\n",
" num_heads: int,\n",
" head_size: int,\n",
" num_layers: int,\n",
" dropout: float = 0.1,\n",
" device: str = 'cpu'\n",
" ):\n",
" super().__init__()\n",
" self._vocab_size = vocab_size\n",
" self._max_seq_len = max_seq_len\n",
" self._emb_size = emb_size\n",
" self._num_heads = num_heads\n",
" self._head_size = head_size\n",
" self._num_layers = num_layers\n",
" self._dropout = dropout\n",
" self._device = device\n",
" \n",
" self.validation_loss = None\n",
"\n",
" # Инициализация слоев\n",
" self._token_embeddings = TokenEmbeddings(\n",
" vocab_size=vocab_size, \n",
" emb_size=emb_size\n",
" )\n",
" self._position_embeddings = PositionalEmbeddings(\n",
" max_seq_len=max_seq_len, \n",
" emb_size=emb_size\n",
" )\n",
" self._dropout = nn.Dropout(dropout)\n",
" self._decoders = nn.ModuleList([Decoder(\n",
" num_heads=num_heads,\n",
" emb_size=emb_size,\n",
" head_size=head_size,\n",
" max_seq_len=max_seq_len,\n",
" dropout=dropout \n",
" ) for _ in range(num_layers)])\n",
" self._norm = RMSNorm(emb_size)\n",
" self._linear = nn.Linear(emb_size, vocab_size)\n",
"\n",
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:\n",
" # Проверка длины последовательности (только при отсутствии кэша)\n",
" if cache is None and x.size(1) > self._max_seq_len:\n",
" raise ValueError(f\"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}\")\n",
" \n",
" \n",
" # Вычисление start_pos из кэша (если кэш передан)\n",
" if cache is not None:\n",
" # При кэше обрабатываем только один токен (последний)\n",
" seq_len = 1\n",
" # Вычисляем start_pos из самого нижнего уровня кэша\n",
" if cache and cache[0] and cache[0][0]:\n",
" key_cache, _ = cache[0][0] # Первый декодер, первая голова\n",
" start_pos = key_cache.size(1) # cache_len\n",
" else:\n",
" start_pos = 0\n",
" else:\n",
" # Без кэша работаем как раньше\n",
" start_pos = 0\n",
" seq_len = x.size(1)\n",
"\n",
" # Эмбеддинги токенов и позиций\n",
" tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]\n",
" pos_out = self._position_embeddings(seq_len, start_pos=start_pos) # [seq_len, emb_size]\n",
" \n",
" # Комбинирование\n",
" out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]\n",
" \n",
" # Стек декодеров с передачей кэша\n",
" new_cache = []\n",
" for i, decoder in enumerate(self._decoders):\n",
" decoder_cache = cache[i] if cache is not None else None\n",
" decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)\n",
"\n",
" # Извлекаем результат из кортежа\n",
" if use_cache:\n",
" out, decoder_new_cache = decoder_result\n",
" new_cache.append(decoder_new_cache)\n",
" else:\n",
" out = decoder_result[0]\n",
"\n",
" out = self._norm(out)\n",
" logits = self._linear(out)\n",
" \n",
" # Возвращаем результат с учетом use_cache\n",
" if use_cache:\n",
" return (logits, new_cache)\n",
" else:\n",
" return (logits, None)\n",
"\n",
" def generate(self,\n",
" x: torch.Tensor, \n",
" max_new_tokens: int, \n",
" do_sample: bool,\n",
" temperature: float = 1.0,\n",
" top_k: int = None,\n",
" top_p: float = None,\n",
" use_cache: bool = True\n",
" ) -> torch.Tensor:\n",
" cache = None\n",
"\n",
" for _ in range(max_new_tokens):\n",
" if use_cache and cache is not None:\n",
" # Используем кэш - передаем только последний токен\n",
" x_input = x[:, -1:] # [batch_size, 1]\n",
" else:\n",
" # Первая итерация или кэш отключен - передаем всю последовательность\n",
" x_input = x\n",
" \n",
" # Прямой проход с кэшем\n",
" logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)\n",
" \n",
" # Обновляем кэш для следующей итерации\n",
" if use_cache:\n",
" cache = new_cache\n",
"\n",
" last_logits = logits[:, -1, :] # [batch_size, vocab_size]\n",
"\n",
" # Масштабируем логиты температурой\n",
" if temperature > 0:\n",
" logits_scaled = last_logits / temperature\n",
" else:\n",
" logits_scaled = last_logits\n",
"\n",
" if do_sample == True and top_k != None:\n",
" _, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)\n",
"\n",
" # # Заменим все НЕ top-k логиты на -inf\n",
" masked_logits = logits_scaled.clone()\n",
" vocab_size = logits_scaled.size(-1)\n",
"\n",
" # создаём маску: 1, если токен НЕ в topk_indices\n",
" mask = torch.ones_like(logits_scaled, dtype=torch.uint8)\n",
" mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы\n",
" masked_logits[mask.byte()] = float('-inf')\n",
"\n",
" logits_scaled = masked_logits\n",
"\n",
" if do_sample == True and top_p != None:\n",
" # 1. Применим softmax, чтобы получить вероятности:\n",
" probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]\n",
" # 2. Отсортируем токены по убыванию вероятностей:\n",
" sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n",
" # 3. Посчитаем кумулятивную сумму вероятностей:\n",
" cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]\n",
" # 4. Определим маску: оставить токены, пока сумма < top_p\n",
" sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]\n",
" # Гарантируем, что хотя бы первый токен останется\n",
" sorted_mask[:, 0] = 1\n",
" # 5. Преобразуем маску обратно в оригинальный порядок:\n",
" # Создаём полную маску из 0\n",
" mask = torch.zeros_like(probs, dtype=torch.uint8)\n",
" # Устанавливаем 1 в местах нужных токенов\n",
" mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)\n",
" # 6. Зануляем логиты токенов вне топ-p:\n",
" logits_scaled[~mask] = float('-inf')\n",
"\n",
" # 4. Применяем Softmax\n",
" probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]\n",
"\n",
"\n",
" if do_sample == True:\n",
" # 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial\n",
" next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]\n",
" else:\n",
" # 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью\n",
" next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]\n",
" \n",
" # 6. Добавляем его к последовательности\n",
" x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]\n",
" return x\n",
"\n",
" def save(self, path):\n",
" torch.save({\n",
" 'model_state_dict': self.state_dict(),\n",
" 'vocab_size': self._vocab_size,\n",
" 'max_seq_len': self._max_seq_len,\n",
" 'emb_size': self._emb_size,\n",
" 'num_heads': self._num_heads,\n",
" 'head_size': self._head_size,\n",
" 'num_layers': self._num_layers\n",
" }, path)\n",
"\n",
" @classmethod\n",
" def load(cls, path, device):\n",
" checkpoint = torch.load(path, map_location=device)\n",
" model = cls(\n",
" vocab_size=checkpoint['vocab_size'],\n",
" max_seq_len=checkpoint['max_seq_len'],\n",
" emb_size=checkpoint['emb_size'],\n",
" num_heads=checkpoint['num_heads'],\n",
" head_size=checkpoint['head_size'],\n",
" num_layers=checkpoint['num_layers']\n",
" )\n",
" model.load_state_dict(checkpoint['model_state_dict'])\n",
" model.to(device)\n",
" return model\n",
"\n",
" @property\n",
" def max_seq_len(self) -> int:\n",
" return self._max_seq_len"
]
},
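{
"cell_type": "markdown",
"id": "0a1b2c3d",
"metadata": {},
"source": [
"A smoke test (a sketch with made-up hyperparameters): build a tiny model, generate a few tokens greedily, and check that decoding with the KV cache yields the same tokens as decoding without it. Note that this intermediate version still uses learned positional embeddings and a GELU feed-forward inside the decoder; of the Llama changes listed above, RMSNorm is the one already wired in."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b2c3d4e",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(0)\n",
"model = Llama(\n",
"    vocab_size=100,\n",
"    max_seq_len=32,\n",
"    emb_size=16,\n",
"    num_heads=2,\n",
"    head_size=8,\n",
"    num_layers=2,\n",
"    dropout=0.0  # disable dropout so the cached and uncached runs match\n",
")\n",
"model.eval()\n",
"\n",
"prompt = torch.randint(0, 100, (1, 4))  # [batch=1, seq_len=4]\n",
"\n",
"with torch.no_grad():\n",
"    out_cached = model.generate(prompt, max_new_tokens=8, do_sample=False, use_cache=True)\n",
"    out_plain = model.generate(prompt, max_new_tokens=8, do_sample=False, use_cache=False)\n",
"\n",
"print(out_cached.shape)                    # torch.Size([1, 12])\n",
"print(torch.equal(out_cached, out_plain))  # expected: True for greedy decoding"
]
},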
{
"cell_type": "code",
"execution_count": null,
"id": "8efb1396",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}