diff --git a/notebooks/llama.ipynb b/notebooks/llama.ipynb
new file mode 100644
index 0000000..be03097
--- /dev/null
+++ b/notebooks/llama.ipynb
@@ -0,0 +1,665 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "efbc675e",
+   "metadata": {},
+   "source": [
+    "# Llama\n",
+    "\n",
+    "![](https://ucarecdn.com/05af6071-73b2-4067-9a39-632fcb2f24e9/)\n",
+    "\n",
+    "\n",
+    "Llama 1 came out in February 2023, a good way past GPT-2, and its architecture introduced several more substantial changes:\n",
+    "\n",
+    "- RMSNorm normalization (along with pre-norm).\n",
+    "- The SwiGLU activation function.\n",
+    "- A new way to encode positions: Rotary Positional Embeddings.\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2cedc663",
+   "metadata": {},
+   "source": [
+    "# RMSNorm\n",
+    "\n",
+    "![2_rmsnorm.png](https://ucarecdn.com/2975a217-27ff-4d26-b4a1-cc48a8de1e45/)\n",
+    "\n",
+    "Llama uses a faster, more efficient normalization: **RMSNorm (Root Mean Square Normalization)**.\n",
+    "And, just as in GPT-2, it is applied as *pre-norm*, i.e. the normalization layers sit **before the attention and FFN blocks**.\n",
+    "\n",
+    "RMSNorm differs from ordinary LayerNorm in exactly one respect: it drops the centering step (subtracting the mean) and keeps only the scaling by the RMS.\n",
+    "This cuts the computational cost (by 7-64%) with no substantial loss in quality.\n",
+    "The figure shows how the distributions differ after applying RMSNorm and LayerNorm to the same data: the RMSNorm output is not re-centered around zero.\n",
+    "\n",
+    "*(figure: output distributions after RMSNorm vs. LayerNorm)*\n",
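+    "\n",
+    "As a quick worked illustration of that claim (the numbers here are arbitrary): for $\\mathbf{x} = (1, 2, 2)$ we get $\\text{RMS}(\\mathbf{x}) = \\sqrt{(1^2 + 2^2 + 2^2)/3} = \\sqrt{3} \\approx 1.732$, so the normalized vector is about $(0.577, 1.155, 1.155)$, and its mean is clearly not zero.\n",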
+    "\n",
+    "## Computing RMSNorm\n",
+    "\n",
+    "1. **Compute the root mean square** (with the small constant $\\varepsilon$ added inside the square root, matching the implementation below):\n",
+    "\n",
+    "   $$\\text{RMS}(\\mathbf{x}) = \\sqrt{\\frac{1}{d} \\sum_{j=1}^{d} x_j^2 + \\varepsilon}$$\n",
+    "\n",
+    "2. **Normalize the input vector:**\n",
+    "\n",
+    "   $$\\hat{x}_i = \\frac{x_i}{\\text{RMS}(\\mathbf{x})}$$\n",
+    "\n",
+    "3. **Apply the learned scaling:**\n",
+    "\n",
+    "   $$y_i = w_i \\cdot \\hat{x}_i$$\n",
+    "\n",
+    "---\n",
+    "\n",
+    "**Where:**\n",
+    "\n",
+    "* $x_i$ is the *i*-th element of the input vector.\n",
+    "* $w_i$ is the *i*-th element of a learnable weight vector.\n",
+    "  The weights let the model adaptively adjust the magnitude of each feature.\n",
+    "  Without them the normalization would be too \"rigid\" and could limit model quality.\n",
+    "* $d$ is the dimensionality of the input vector.\n",
+    "* $\\varepsilon$ is a small constant (e.g. 1e-6) that prevents division by zero.\n",
+    "\n",
+    "---\n",
+    "\n",
+    "Since the input is a tensor, in vector form RMSNorm is computed as:\n",
+    "\n",
+    "$$\n",
+    "\\text{RMSNorm}(x) = w \\odot \\frac{x}{\\sqrt{\\text{mean}(x^2) + \\varepsilon}}\n",
+    "$$\n",
+    "\n",
+    "**Where:**\n",
+    "\n",
+    "* $x$ is the input tensor of shape `batch_size × ...`\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "873704be",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch import nn\n",
+    "\n",
+    "class RMSNorm(nn.Module):\n",
+    "    def __init__(self, dim: int, eps: float = 1e-6):\n",
+    "        super().__init__()\n",
+    "        self._eps = eps\n",
+    "        self._w = nn.Parameter(torch.ones(dim))\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]\n",
+    "        # eps sits inside the square root, as in step 1 above\n",
+    "        rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5\n",
+    "        norm_x = x / rms\n",
+    "        return self._w * norm_x"
+   ]
+  },
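+  {
+   "cell_type": "markdown",
+   "id": "rmsnchk1",
+   "metadata": {},
+   "source": [
+    "A minimal sanity-check sketch of the class above (the shapes and seed are arbitrary, not from the original text): with the weights at their initial value of 1, every normalized vector should have an RMS of roughly 1, while its mean stays away from zero, since RMSNorm does not re-center.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "rmsnchk2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.manual_seed(0)\n",
+    "x = torch.randn(2, 5, 8)         # [batch_size, seq_len, emb_size]\n",
+    "norm = RMSNorm(dim=8)\n",
+    "y = norm(x)\n",
+    "print(y.pow(2).mean(-1).sqrt())  # ~1.0 everywhere\n",
+    "print(y.mean(-1))                # not ~0: there is no centering step\n"
+   ]
+  },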
+  {
+   "cell_type": "markdown",
+   "id": "09dd9625",
+   "metadata": {},
+   "source": [
+    "# SwiGLU\n",
+    "\n",
+    "![3_swiglu.png](https://ucarecdn.com/120dea77-2bf2-455f-9b54-c35c4efddc9e/)\n",
+    "\n",
+    "**Llama** introduced a new activation function, **SwiGLU (Swish-Gated Linear Unit)**: a hybrid activation that combines three linear transformations with the **SiLU (Sigmoid Linear Unit)** activation, also known as *Swish* in Google's terminology.\n",
+    "\n",
+    "The SwiGLU formula is:\n",
+    "\n",
+    "$$\n",
+    "\\text{SwiGLU}(x) = \\text{down}(\\text{SiLU}(\\text{gate}(x)) \\otimes \\text{up}(x))\n",
+    "$$\n",
+    "\n",
+    "where:\n",
+    "\n",
+    "* $x$ is the input tensor.\n",
+    "* $\\text{gate}(x)$ is the linear layer for the gating mechanism. It projects the input `x` of size `emb_size` to an intermediate representation of size `4 * emb_size`.\n",
+    "* $\\text{up}(x)$ is the linear layer that expands the dimensionality. It likewise projects `x` to size `4 * emb_size`.\n",
+    "* $\\text{SiLU}(x) = x \\cdot \\sigma(x)$ is the activation function, where $\\sigma$ is the sigmoid.\n",
+    "* $\\otimes$ is element-wise multiplication.\n",
+    "* $\\text{down}(x)$ is the linear layer that shrinks the intermediate representation back to the original size (`emb_size`).\n",
+    "\n",
+    "> **Gating** (from the word *gate*) is a mechanism that lets the network dynamically filter which information is allowed to flow onward.\n",
+    "> Gating effectively creates two independent streams:\n",
+    ">\n",
+    "> * one carries the information itself (*up-down*),\n",
+    "> * the other controls what gets passed through (*gate*).\n",
+    ">\n",
+    "> This lets the network learn more complex patterns.\n",
+    "> For example, a gate can learn a rule like\n",
+    "> \"if feature A is active, let feature B through\",\n",
+    "> which is impossible with a plain activation function between linear layers.\n",
+    ">\n",
+    "> Gating also helps with vanishing gradients: instead of hard zeroing (as in ReLU), a gate can modulate the signal strength smoothly.\n",
+    "\n",
+    "SwiGLU is more complex (and more expensive) than ReLU/GELU, since it requires more computation (three linear transformations instead of two),\n",
+    "but it delivers better quality than both ReLU and GELU.\n",
+    "\n",
+    "The **SiLU** curve looks similar to **GELU**, but smoother:\n",
+    "\n",
+    "*(figure: SiLU vs. GELU activation curves)*\n"
+   ]
+  },
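+  {
+   "cell_type": "markdown",
+   "id": "silucmp1",
+   "metadata": {},
+   "source": [
+    "As a rough numerical illustration of that comparison (an added sketch, not part of the original text), one can evaluate PyTorch's built-in `F.silu` and `F.gelu` on a few points:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "silucmp2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "xs = torch.linspace(-4, 4, 9)\n",
+    "print('x    :', xs.tolist())\n",
+    "print('silu :', [round(v, 3) for v in F.silu(xs).tolist()])\n",
+    "print('gelu :', [round(v, 3) for v in F.gelu(xs).tolist()])\n",
+    "# Both curves are smooth, dip slightly below zero for negative inputs,\n",
+    "# and approach the identity (like ReLU) for large positive x.\n"
+   ]
+  },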
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "0484cf77",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch import nn\n",
+    "import torch.nn.functional as F\n",
+    "\n",
+    "class SiLU(nn.Module):\n",
+    "    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]\n",
+    "        return torch.sigmoid(x) * x"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0b64da5d",
+   "metadata": {},
+   "source": [
+    "## SwiGLU"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "74ca39ba",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "from torch import nn\n",
+    "\n",
+    "class SwiGLU(nn.Module):\n",
+    "    def __init__(self, emb_size: int, dropout: float = 0.1):\n",
+    "        super().__init__()\n",
+    "\n",
+    "        self._gate = nn.Linear(emb_size, 4 * emb_size)\n",
+    "        self._up = nn.Linear(emb_size, 4 * emb_size)\n",
+    "        self._down = nn.Linear(4 * emb_size, emb_size)\n",
+    "        self._activation = SiLU()\n",
+    "        self._dropout = nn.Dropout(dropout)\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor):  # [batch_size × seq_len × emb_size]\n",
+    "        gate_out = self._gate(x)                     # [batch, seq, 4*emb]\n",
+    "        activation_out = self._activation(gate_out)  # [batch, seq, 4*emb]\n",
+    "        up_out = self._up(x)                         # [batch, seq, 4*emb]\n",
+    "        out = up_out * activation_out                # element-wise!\n",
+    "        out = self._down(out)                        # [batch, seq, emb]\n",
+    "        return self._dropout(out)"
+   ]
+  },
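+  {
+   "cell_type": "markdown",
+   "id": "swigluchk1",
+   "metadata": {},
+   "source": [
+    "A minimal usage sketch (toy shapes, chosen arbitrarily for illustration): SwiGLU expands to `4 * emb_size` internally, but it must return a tensor of the same shape it received.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "swigluchk2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "ffn = SwiGLU(emb_size=8)\n",
+    "x = torch.randn(2, 5, 8)     # [batch_size, seq_len, emb_size]\n",
+    "out = ffn(x)\n",
+    "assert out.shape == x.shape  # the down-projection restores emb_size\n",
+    "print(out.shape)\n"
+   ]
+  },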
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "fe1274b1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import math\n",
+    "from math import sqrt\n",
+    "\n",
+    "import torch\n",
+    "import torch.nn.functional as F\n",
+    "from torch import nn, Tensor\n",
+    "\n",
+    "class TokenEmbeddings(nn.Module):\n",
+    "    def __init__(self, vocab_size: int, emb_size: int):\n",
+    "        super().__init__()\n",
+    "        self._embedding = nn.Embedding(\n",
+    "            num_embeddings=vocab_size,\n",
+    "            embedding_dim=emb_size\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, x: Tensor) -> Tensor:\n",
+    "        return self._embedding(x)\n",
+    "\n",
+    "    @property\n",
+    "    def num_embeddings(self) -> int:\n",
+    "        return self._embedding.num_embeddings\n",
+    "\n",
+    "    @property\n",
+    "    def embedding_dim(self) -> int:\n",
+    "        return self._embedding.embedding_dim\n",
+    "\n",
+    "\n",
+    "class PositionalEmbeddings(nn.Module):\n",
+    "    def __init__(self, max_seq_len: int, emb_size: int):\n",
+    "        super().__init__()\n",
+    "        self.max_seq_len = max_seq_len\n",
+    "        self.emb_size = emb_size\n",
+    "        self.embedding = nn.Embedding(\n",
+    "            num_embeddings=max_seq_len,\n",
+    "            embedding_dim=emb_size\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:\n",
+    "        if seq_len < 1 or seq_len > self.max_seq_len:\n",
+    "            raise IndexError(f\"Length {seq_len} must be between 1 and {self.max_seq_len}\")\n",
+    "        positions = torch.arange(\n",
+    "            start=start_pos,\n",
+    "            end=start_pos + seq_len,\n",
+    "            device=self.embedding.weight.device\n",
+    "        )\n",
+    "        return self.embedding(positions)\n",
+    "\n",
+    "\n",
+    "class HeadAttention(nn.Module):\n",
+    "    def __init__(self, emb_size: int, head_size: int, max_seq_len: int):\n",
+    "        super().__init__()\n",
+    "        self._emb_size = emb_size\n",
+    "        self._head_size = head_size\n",
+    "        self._max_seq_len = max_seq_len\n",
+    "\n",
+    "        self._k = nn.Linear(emb_size, head_size)\n",
+    "        self._q = nn.Linear(emb_size, head_size)\n",
+    "        self._v = nn.Linear(emb_size, head_size)\n",
+    "\n",
+    "        # Lower-triangular causal mask, stored as a buffer (not a parameter)\n",
+    "        mask = torch.tril(torch.ones(max_seq_len, max_seq_len))\n",
+    "        self.register_buffer('_tril_mask', mask.bool())\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n",
+    "        seq_len = x.shape[1]\n",
+    "        if seq_len > self._max_seq_len:\n",
+    "            raise ValueError(f\"Sequence length {seq_len} exceeds the maximum {self._max_seq_len}\")\n",
+    "\n",
+    "        k = self._k(x)  # [B, T, hs]\n",
+    "        q = self._q(x)  # [B, T, hs]\n",
+    "        v = self._v(x)  # [B, T, hs]\n",
+    "\n",
+    "        if cache is not None:\n",
+    "            k_cache, v_cache = cache\n",
+    "            k = torch.cat([k_cache, k], dim=1)  # [B, cache_len + T, hs]\n",
+    "            v = torch.cat([v_cache, v], dim=1)  # [B, cache_len + T, hs]\n",
+    "\n",
+    "        scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)\n",
+    "\n",
+    "        # On the cached (single-token) path the query may attend to every\n",
+    "        # cached position, so the causal mask is only needed without a cache.\n",
+    "        if cache is None:\n",
+    "            scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))\n",
+    "\n",
+    "        weights = F.softmax(scores, dim=-1)\n",
+    "        x_out = weights @ v  # [B, T, hs]\n",
+    "\n",
+    "        if use_cache:\n",
+    "            return (x_out, (k, v))\n",
+    "        return (x_out, None)\n",
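+    "\n",
+    "# Two illustrative notes (added remarks, not from the original notebook):\n",
+    "# 1) KV-cache shapes during generation: on the first call cache is None, the\n",
+    "#    whole prompt of length T is processed, and (k, v) of shape [B, T, hs] is\n",
+    "#    returned; on every later call a single token comes in, its k/v of shape\n",
+    "#    [B, 1, hs] are appended, and scores have shape [B, 1, cache_len + 1].\n",
+    "# 2) PositionalEmbeddings above are learned absolute positions; Llama itself\n",
+    "#    uses RoPE, which this notebook has not implemented yet.\n",
+    "\n",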
+    "\n",
+    "class MultiHeadAttention(nn.Module):\n",
+    "    def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, dropout: float = 0.1):\n",
+    "        super().__init__()\n",
+    "        self._heads = nn.ModuleList([\n",
+    "            HeadAttention(\n",
+    "                emb_size=emb_size,\n",
+    "                head_size=head_size,\n",
+    "                max_seq_len=max_seq_len\n",
+    "            ) for _ in range(num_heads)\n",
+    "        ])\n",
+    "        self._layer = nn.Linear(head_size * num_heads, emb_size)\n",
+    "        self._dropout = nn.Dropout(dropout)\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):\n",
+    "        # Run every head, handing each one its own slice of the cache\n",
+    "        attention_results = []\n",
+    "        for i, head in enumerate(self._heads):\n",
+    "            head_cache = cache[i] if cache is not None else None\n",
+    "            result = head(x, use_cache=use_cache, cache=head_cache)\n",
+    "            attention_results.append(result)\n",
+    "\n",
+    "        outputs, caches = zip(*attention_results)\n",
+    "        attention_outputs = list(outputs)\n",
+    "        kv_caches = list(caches)\n",
+    "\n",
+    "        concatenated_attention = torch.cat(attention_outputs, dim=-1)\n",
+    "        projected_output = self._layer(concatenated_attention)\n",
+    "        final_output = self._dropout(projected_output)\n",
+    "\n",
+    "        if use_cache:\n",
+    "            return (final_output, kv_caches)\n",
+    "        return (final_output, None)\n",
+    "\n",
+    "\n",
+    "class GELU(nn.Module):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
+    "        # tanh approximation of GELU\n",
+    "        return 0.5 * x * (1 + torch.tanh(\n",
+    "            self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))\n",
+    "        ))\n",
+    "\n",
+    "\n",
+    "class FeedForward(nn.Module):\n",
+    "    def __init__(self, emb_size: int, dropout: float = 0.1):\n",
+    "        super().__init__()\n",
+    "        self._layer1 = nn.Linear(emb_size, emb_size * 4)\n",
+    "        self._gelu = GELU()\n",
+    "        self._layer2 = nn.Linear(emb_size * 4, emb_size)\n",
+    "        self._dropout = nn.Dropout(dropout)\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor):\n",
+    "        # Keep the layer weights in the same dtype as the input\n",
+    "        input_dtype = x.dtype\n",
+    "        if input_dtype != self._layer1.weight.dtype:\n",
+    "            self._layer1 = self._layer1.to(dtype=input_dtype)\n",
+    "            self._layer2 = self._layer2.to(dtype=input_dtype)\n",
+    "\n",
+    "        x = self._layer1(x)\n",
+    "        x = self._gelu(x)\n",
+    "        x = self._layer2(x)\n",
+    "        return self._dropout(x)\n",
+    "\n",
+    "\n",
+    "class Decoder(nn.Module):\n",
+    "    def __init__(self,\n",
+    "                 num_heads: int,\n",
+    "                 emb_size: int,\n",
+    "                 head_size: int,\n",
+    "                 max_seq_len: int,\n",
+    "                 dropout: float = 0.1\n",
+    "                 ):\n",
+    "        super().__init__()\n",
+    "        self._heads = MultiHeadAttention(\n",
+    "            num_heads=num_heads,\n",
+    "            emb_size=emb_size,\n",
+    "            head_size=head_size,\n",
+    "            max_seq_len=max_seq_len,\n",
+    "            dropout=dropout\n",
+    "        )\n",
+    "        self._ff = FeedForward(emb_size=emb_size, dropout=dropout)\n",
+    "        self._norm1 = RMSNorm(emb_size)\n",
+    "        self._norm2 = RMSNorm(emb_size)\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> tuple:\n",
+    "        norm1_out = self._norm1(x)\n",
+    "        attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)\n",
+    "        out = attention + x\n",
+    "\n",
+    "        norm2_out = self._norm2(out)\n",
+    "        ffn_out = self._ff(norm2_out)\n",
+    "\n",
+    "        if use_cache:\n",
+    "            return (ffn_out + out, kv_caches)\n",
+    "        return (ffn_out + out, None)\n",
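+    "\n",
+    "# Note (added remark): the Decoder above follows the pre-norm layout described\n",
+    "# earlier, applying RMSNorm before attention and before the FFN, with residual\n",
+    "# connections around both. It still plugs in the GELU FeedForward; swapping it\n",
+    "# for the SwiGLU module defined above would match Llama's FFN more closely.\n",
+    "\n",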
+    "\n",
+    "class Llama(nn.Module):\n",
+    "    def __init__(self,\n",
+    "                 vocab_size: int,\n",
+    "                 max_seq_len: int,\n",
+    "                 emb_size: int,\n",
+    "                 num_heads: int,\n",
+    "                 head_size: int,\n",
+    "                 num_layers: int,\n",
+    "                 dropout: float = 0.1,\n",
+    "                 device: str = 'cpu'\n",
+    "                 ):\n",
+    "        super().__init__()\n",
+    "        self._vocab_size = vocab_size\n",
+    "        self._max_seq_len = max_seq_len\n",
+    "        self._emb_size = emb_size\n",
+    "        self._num_heads = num_heads\n",
+    "        self._head_size = head_size\n",
+    "        self._num_layers = num_layers\n",
+    "        self._device = device\n",
+    "\n",
+    "        self.validation_loss = None\n",
+    "\n",
+    "        # Layer initialization\n",
+    "        self._token_embeddings = TokenEmbeddings(\n",
+    "            vocab_size=vocab_size,\n",
+    "            emb_size=emb_size\n",
+    "        )\n",
+    "        self._position_embeddings = PositionalEmbeddings(\n",
+    "            max_seq_len=max_seq_len,\n",
+    "            emb_size=emb_size\n",
+    "        )\n",
+    "        self._dropout = nn.Dropout(dropout)\n",
+    "        self._decoders = nn.ModuleList([Decoder(\n",
+    "            num_heads=num_heads,\n",
+    "            emb_size=emb_size,\n",
+    "            head_size=head_size,\n",
+    "            max_seq_len=max_seq_len,\n",
+    "            dropout=dropout\n",
+    "        ) for _ in range(num_layers)])\n",
+    "        self._norm = RMSNorm(emb_size)\n",
+    "        self._linear = nn.Linear(emb_size, vocab_size)\n",
+    "\n",
+    "    def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:\n",
+    "        # Check the sequence length (only when there is no cache)\n",
+    "        if cache is None and x.size(1) > self._max_seq_len:\n",
+    "            raise ValueError(f\"Sequence length {x.size(1)} exceeds the maximum {self._max_seq_len}\")\n",
+    "\n",
+    "        # Derive start_pos from the cache (when one is passed in)\n",
+    "        if cache is not None:\n",
+    "            # With a cache we only process a single token (the latest one)\n",
+    "            seq_len = 1\n",
+    "            # Read start_pos off the lowest level of the cache\n",
+    "            if cache and cache[0] and cache[0][0]:\n",
+    "                key_cache, _ = cache[0][0]  # first decoder, first head\n",
+    "                start_pos = key_cache.size(1)  # cache_len\n",
+    "            else:\n",
+    "                start_pos = 0\n",
+    "        else:\n",
+    "            # Without a cache, behave as before\n",
+    "            start_pos = 0\n",
+    "            seq_len = x.size(1)\n",
+    "\n",
+    "        # Token and positional embeddings\n",
+    "        tok_out = self._token_embeddings(x)  # [batch, seq_len, emb_size]\n",
+    "        pos_out = self._position_embeddings(seq_len, start_pos=start_pos)  # [seq_len, emb_size]\n",
+    "\n",
+    "        # Combine them\n",
+    "        out = self._dropout(tok_out + pos_out.unsqueeze(0))  # [batch, seq_len, emb_size]\n",
+    "\n",
+    "        # Decoder stack, threading the cache through\n",
+    "        new_cache = []\n",
+    "        for i, decoder in enumerate(self._decoders):\n",
+    "            decoder_cache = cache[i] if cache is not None else None\n",
+    "            decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)\n",
+    "\n",
+    "            # Unpack the result tuple\n",
+    "            if use_cache:\n",
+    "                out, decoder_new_cache = decoder_result\n",
+    "                new_cache.append(decoder_new_cache)\n",
+    "            else:\n",
+    "                out = decoder_result[0]\n",
+    "\n",
+    "        out = self._norm(out)\n",
+    "        logits = self._linear(out)\n",
+    "\n",
+    "        # Return the result, honouring use_cache\n",
+    "        if use_cache:\n",
+    "            return (logits, new_cache)\n",
+    "        return (logits, None)\n",
+    "\n",
+    "    def generate(self,\n",
+    "                 x: torch.Tensor,\n",
+    "                 max_new_tokens: int,\n",
+    "                 do_sample: bool,\n",
+    "                 temperature: float = 1.0,\n",
+    "                 top_k: int = None,\n",
+    "                 top_p: float = None,\n",
+    "                 use_cache: bool = True\n",
+    "                 ) -> torch.Tensor:\n",
+    "        cache = None\n",
+    "\n",
+    "        for _ in range(max_new_tokens):\n",
+    "            if use_cache and cache is not None:\n",
+    "                # Cache available: feed only the most recent token\n",
+    "                x_input = x[:, -1:]  # [batch_size, 1]\n",
+    "            else:\n",
+    "                # First iteration, or the cache is disabled: feed the whole sequence\n",
+    "                x_input = x\n",
+    "\n",
+    "            # Forward pass, with the cache\n",
+    "            logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)\n",
+    "\n",
+    "            # Keep the cache for the next iteration\n",
+    "            if use_cache:\n",
+    "                cache = new_cache\n",
+    "\n",
+    "            last_logits = logits[:, -1, :]  # [batch_size, vocab_size]\n",
+    "\n",
+    "            # 1. Scale the logits by the temperature\n",
+    "            if temperature > 0:\n",
+    "                logits_scaled = last_logits / temperature\n",
+    "            else:\n",
+    "                logits_scaled = last_logits\n",
+    "\n",
+    "            # 2. Top-k filtering: replace every logit outside the top k with -inf\n",
+    "            if do_sample and top_k is not None:\n",
+    "                _, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)\n",
+    "                masked_logits = logits_scaled.clone()\n",
+    "                # Build a mask that is True for tokens NOT in topk_indices\n",
+    "                mask = torch.ones_like(logits_scaled, dtype=torch.bool)\n",
+    "                mask.scatter_(1, topk_indices, False)  # False at the top-k indices\n",
+    "                masked_logits[mask] = float('-inf')\n",
+    "                logits_scaled = masked_logits\n",
+    "\n",
+    "            # 3. Top-p (nucleus) filtering\n",
+    "            if do_sample and top_p is not None:\n",
+    "                # Apply softmax to get probabilities\n",
+    "                probs = F.softmax(logits_scaled, dim=-1)  # [B, vocab_size]\n",
+    "                # Sort tokens by descending probability\n",
+    "                sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n",
+    "                # Cumulative sum of the sorted probabilities\n",
+    "                cum_probs = torch.cumsum(sorted_probs, dim=-1)  # [B, vocab_size]\n",
+    "                # Keep tokens while the cumulative sum stays under top_p\n",
+    "                sorted_mask = cum_probs <= top_p  # [B, vocab_size]\n",
+    "                # Guarantee that at least the first token survives\n",
+    "                sorted_mask[:, 0] = True\n",
+    "                # Map the mask back to the original token order\n",
+    "                mask = torch.zeros_like(probs, dtype=torch.bool)\n",
+    "                mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)\n",
+    "                # Silence the logits of tokens outside the nucleus\n",
+    "                logits_scaled[~mask] = float('-inf')\n",
+    "\n",
+    "            # 4. Softmax over the filtered logits\n",
+    "            probs = F.softmax(logits_scaled, dim=-1)  # [batch_size, vocab_size]\n",
+    "\n",
+    "            # 5. Pick the next token: sample with torch.multinomial, or take the argmax\n",
+    "            if do_sample:\n",
+    "                next_token = torch.multinomial(probs, num_samples=1)  # [batch_size, 1]\n",
+    "            else:\n",
+    "                next_token = torch.argmax(probs, dim=-1, keepdim=True)  # [batch_size, 1]\n",
+    "\n",
+    "            # 6. Append it to the sequence\n",
+    "            x = torch.cat([x, next_token], dim=1)  # [batch_size, seq_len+1]\n",
+    "        return x\n",
+    "\n",
+    "    def save(self, path):\n",
+    "        torch.save({\n",
+    "            'model_state_dict': self.state_dict(),\n",
+    "            'vocab_size': self._vocab_size,\n",
+    "            'max_seq_len': self._max_seq_len,\n",
+    "            'emb_size': self._emb_size,\n",
+    "            'num_heads': self._num_heads,\n",
+    "            'head_size': self._head_size,\n",
+    "            'num_layers': self._num_layers\n",
+    "        }, path)\n",
+    "\n",
+    "    @classmethod\n",
+    "    def load(cls, path, device):\n",
+    "        checkpoint = torch.load(path, map_location=device)\n",
+    "        model = cls(\n",
+    "            vocab_size=checkpoint['vocab_size'],\n",
+    "            max_seq_len=checkpoint['max_seq_len'],\n",
+    "            emb_size=checkpoint['emb_size'],\n",
+    "            num_heads=checkpoint['num_heads'],\n",
+    "            head_size=checkpoint['head_size'],\n",
+    "            num_layers=checkpoint['num_layers']\n",
+    "        )\n",
+    "        model.load_state_dict(checkpoint['model_state_dict'])\n",
+    "        model.to(device)\n",
+    "        return model\n",
+    "\n",
+    "    @property\n",
+    "    def max_seq_len(self) -> int:\n",
+    "        return self._max_seq_len"
+   ]
+  },
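+  {
+   "cell_type": "markdown",
+   "id": "llamause1",
+   "metadata": {},
+   "source": [
+    "A small smoke test of the full model (the hyperparameters below are arbitrary toy values, not from the original text). In eval mode, greedy decoding with and without the KV cache should produce the same tokens, since the cache only changes how the computation is organized, not its result.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "llamause2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch.manual_seed(0)\n",
+    "model = Llama(\n",
+    "    vocab_size=100,\n",
+    "    max_seq_len=32,\n",
+    "    emb_size=16,\n",
+    "    num_heads=2,\n",
+    "    head_size=8,\n",
+    "    num_layers=2\n",
+    ")\n",
+    "model.eval()  # disable dropout so both runs are deterministic\n",
+    "\n",
+    "prompt = torch.randint(0, 100, (1, 5))  # [batch_size, seq_len]\n",
+    "with torch.no_grad():\n",
+    "    out_cached = model.generate(prompt, max_new_tokens=8, do_sample=False, use_cache=True)\n",
+    "    out_plain = model.generate(prompt, max_new_tokens=8, do_sample=False, use_cache=False)\n",
+    "\n",
+    "print(out_cached)\n",
+    "print(torch.equal(out_cached, out_plain))  # expected: True\n"
+   ]
+  }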
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}