From e5706a690d67bb3ab2a136159a6414934e24ce40 Mon Sep 17 00:00:00 2001
From: Sergey Penkovsky
Date: Tue, 14 Oct 2025 12:03:20 +0300
Subject: [PATCH] fix(rope, attention): correct RoPE positioning during
 generation with the cache
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Fixed the position calculation for RoPE (Rotary Positional Embeddings)
  during autocompletion with the cache.
- HeadAttention now passes start_pos to RoPE, computed from the cache length.
- Updated the signature and logic of RoPE.forward.
- Updated the llama.ipynb notebook for the new interfaces and outputs.

BREAKING CHANGE: RoPE.forward has been redefined; code that called RoPE
manually must be updated.
---
 llm/src/llm/core/head_attention.py |  10 +-
 llm/src/llm/core/rope.py           |  10 +-
 notebooks/llama.ipynb              | 275 +++++++++++++++--------------
 3 files changed, 155 insertions(+), 140 deletions(-)

diff --git a/llm/src/llm/core/head_attention.py b/llm/src/llm/core/head_attention.py
index 410f499..194c706 100644
--- a/llm/src/llm/core/head_attention.py
+++ b/llm/src/llm/core/head_attention.py
@@ -89,10 +89,16 @@ class HeadAttention(nn.Module):
         q = self._q(x)  # [B, T, hs]
         v = self._v(x)  # [B, T, hs]
 
+        start_pos = 0
+        if cache is not None:
+            k_cache, v_cache = cache
+            cache_len = k_cache.shape[1]
+            start_pos = cache_len
+
         if self._rope is not None:
             # ✅ Применяем RoPE к Q и K (НЕ к V!)
-            q = self._rope(q)  # [B, T, hs]
-            k = self._rope(k)  # [B, T, hs]
+            q = self._rope(q, start_pos=start_pos)  # [B, T, hs]
+            k = self._rope(k, start_pos=start_pos)  # [B, T, hs]
 
         if cache is not None:
             k_cache, v_cache = cache
diff --git a/llm/src/llm/core/rope.py b/llm/src/llm/core/rope.py
index ebf188a..7a8801c 100644
--- a/llm/src/llm/core/rope.py
+++ b/llm/src/llm/core/rope.py
@@ -68,7 +68,7 @@ class RoPE(nn.Module):
         self.register_buffer("cos_matrix", torch.cos(freq_matrix))
         self.register_buffer("sin_matrix", torch.sin(freq_matrix))
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
         """
         Применение ротационного позиционного кодирования к входному тензору.
 
@@ -83,12 +83,12 @@ class RoPE(nn.Module):
         2. Применение вращения через синусы и косинусы
         3. 
Объединение компонент обратно """ - seq_len = x.size(1) + batch_size, seq_len, emb_size = x.shape # Берем нужную часть матриц и приводим к типу x - cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2] - sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2] - + cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] + sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2] + # Разделяем на четные и нечетные компоненты x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2] x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2] diff --git a/notebooks/llama.ipynb b/notebooks/llama.ipynb index 34e487f..fedba8e 100644 --- a/notebooks/llama.ipynb +++ b/notebooks/llama.ipynb @@ -78,7 +78,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 19, "id": "873704be", "metadata": {}, "outputs": [], @@ -149,7 +149,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 20, "id": "0484cf77", "metadata": {}, "outputs": [], @@ -173,7 +173,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 21, "id": "74ca39ba", "metadata": {}, "outputs": [], @@ -432,7 +432,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 22, "id": "02c300f9", "metadata": {}, "outputs": [ @@ -480,7 +480,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 23, "id": "ac072b9b", "metadata": {}, "outputs": [ @@ -515,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 24, "id": "4fe79d02", "metadata": {}, "outputs": [], @@ -543,11 +543,12 @@ " self.register_buffer('sin_matrix', torch.sin(freq_matrix))\n", "\n", "\n", - " def forward(self, x: torch.Tensor): # Получает на вход тензор x (тип float) размером [batch_size × seq_len × head_size]\n", - " seq_len = x.size(1)\n", + " def forward(self, x: torch.Tensor, start_pos: int = 0): # Получает на вход тензор x (тип float) размером [batch_size × seq_len × head_size]\n", + " batch_size, seq_len, emb_size = x.shape\n", + "\n", " # Берем нужную часть матриц и приводим к типу x\n", - " cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n", - " sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n", + " cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n", + " sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n", " \n", "\n", " # Разделяем на четные и нечетные\n", @@ -580,7 +581,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 25, "id": "fe1274b1", "metadata": {}, "outputs": [], @@ -630,7 +631,9 @@ " self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())\n", "\n", " def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n", - " seq_len = x.shape[1]\n", + " #seq_len = x.shape[1]\n", + " batch_size, seq_len, emb_size = x.shape\n", + "\n", " if seq_len > self._max_seq_len:\n", " raise ValueError(f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\")\n", "\n", @@ -638,9 +641,15 @@ " q = self._q(x) # [B, T, hs]\n", " v = self._v(x) # [B, T, hs]\n", "\n", + " start_pos = 0\n", + " if cache is not None:\n", + " k_cache, v_cache = cache\n", + " cache_len = k_cache.shape[1]\n", + " start_pos = cache_len\n", + "\n", " # ✅ Применяем RoPE к Q и K (НЕ к V!)\n", - " q = self._rope(q) # [B, T, hs]\n", - " k = self._rope(k) # [B, T, 
hs]\n", + " q = self._rope(q, start_pos=start_pos) # [B, T, hs]\n", + " k = self._rope(k, start_pos=start_pos) # [B, T, hs]\n", "\n", " if cache is not None:\n", " k_cache, v_cache = cache\n", @@ -1031,7 +1040,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 26, "id": "078ed4ce", "metadata": {}, "outputs": [], @@ -1177,7 +1186,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 27, "id": "0258c483", "metadata": {}, "outputs": [], @@ -1212,7 +1221,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 28, "id": "f8f06cf7", "metadata": {}, "outputs": [], @@ -1275,7 +1284,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 29, "id": "0c753d52", "metadata": {}, "outputs": [ @@ -1284,106 +1293,106 @@ "output_type": "stream", "text": [ "Dataset length: 20\n", - "Epoch 1/100, Loss: 3.9063\n", - "Epoch 2/100, Loss: 1.6363\n", - "Epoch 3/100, Loss: 0.6436\n", - "Epoch 4/100, Loss: 0.3222\n", - "Epoch 5/100, Loss: 0.2031\n", - "Epoch 6/100, Loss: 0.1460\n", - "Epoch 7/100, Loss: 0.1119\n", - "Epoch 8/100, Loss: 0.0978\n", - "Epoch 9/100, Loss: 0.0925\n", - "Epoch 10/100, Loss: 0.0757\n", - "Epoch 11/100, Loss: 0.0809\n", - "Epoch 12/100, Loss: 0.0763\n", - "Epoch 13/100, Loss: 0.0671\n", - "Epoch 14/100, Loss: 0.0681\n", - "Epoch 15/100, Loss: 0.0582\n", - "Epoch 16/100, Loss: 0.0655\n", - "Epoch 17/100, Loss: 0.0751\n", - "Epoch 18/100, Loss: 0.0656\n", - "Epoch 19/100, Loss: 0.0625\n", - "Epoch 20/100, Loss: 0.0593\n", - "Epoch 21/100, Loss: 0.0678\n", - "Epoch 22/100, Loss: 0.0652\n", - "Epoch 23/100, Loss: 0.0644\n", - "Epoch 24/100, Loss: 0.0542\n", - "Epoch 25/100, Loss: 0.0627\n", - "Epoch 26/100, Loss: 0.0596\n", - "Epoch 27/100, Loss: 0.0639\n", - "Epoch 28/100, Loss: 0.0536\n", - "Epoch 29/100, Loss: 0.0606\n", - "Epoch 30/100, Loss: 0.0565\n", - "Epoch 31/100, Loss: 0.0622\n", - "Epoch 32/100, Loss: 0.0581\n", - "Epoch 33/100, Loss: 0.0492\n", - "Epoch 34/100, Loss: 0.0553\n", - "Epoch 35/100, Loss: 0.0526\n", - "Epoch 36/100, Loss: 0.0527\n", + "Epoch 1/100, Loss: 3.7312\n", + "Epoch 2/100, Loss: 1.5022\n", + "Epoch 3/100, Loss: 0.5940\n", + "Epoch 4/100, Loss: 0.3368\n", + "Epoch 5/100, Loss: 0.2121\n", + "Epoch 6/100, Loss: 0.1507\n", + "Epoch 7/100, Loss: 0.1095\n", + "Epoch 8/100, Loss: 0.1023\n", + "Epoch 9/100, Loss: 0.0831\n", + "Epoch 10/100, Loss: 0.0805\n", + "Epoch 11/100, Loss: 0.0789\n", + "Epoch 12/100, Loss: 0.0802\n", + "Epoch 13/100, Loss: 0.0758\n", + "Epoch 14/100, Loss: 0.0676\n", + "Epoch 15/100, Loss: 0.0655\n", + "Epoch 16/100, Loss: 0.0701\n", + "Epoch 17/100, Loss: 0.0631\n", + "Epoch 18/100, Loss: 0.0691\n", + "Epoch 19/100, Loss: 0.0687\n", + "Epoch 20/100, Loss: 0.0623\n", + "Epoch 21/100, Loss: 0.0574\n", + "Epoch 22/100, Loss: 0.0603\n", + "Epoch 23/100, Loss: 0.0592\n", + "Epoch 24/100, Loss: 0.0587\n", + "Epoch 25/100, Loss: 0.0598\n", + "Epoch 26/100, Loss: 0.0589\n", + "Epoch 27/100, Loss: 0.0589\n", + "Epoch 28/100, Loss: 0.0546\n", + "Epoch 29/100, Loss: 0.0570\n", + "Epoch 30/100, Loss: 0.0673\n", + "Epoch 31/100, Loss: 0.0601\n", + "Epoch 32/100, Loss: 0.0702\n", + "Epoch 33/100, Loss: 0.0528\n", + "Epoch 34/100, Loss: 0.0508\n", + "Epoch 35/100, Loss: 0.0522\n", + "Epoch 36/100, Loss: 0.0537\n", "Epoch 37/100, Loss: 0.0570\n", - "Epoch 38/100, Loss: 0.0482\n", - "Epoch 39/100, Loss: 0.0553\n", - "Epoch 40/100, Loss: 0.0444\n", - "Epoch 41/100, Loss: 0.0602\n", - "Epoch 42/100, Loss: 0.0599\n", - "Epoch 43/100, Loss: 0.0598\n", - "Epoch 44/100, Loss: 
0.0572\n", - "Epoch 45/100, Loss: 0.0551\n", - "Epoch 46/100, Loss: 0.0577\n", - "Epoch 47/100, Loss: 0.0527\n", - "Epoch 48/100, Loss: 0.0466\n", - "Epoch 49/100, Loss: 0.0551\n", - "Epoch 50/100, Loss: 0.0517\n", - "Epoch 51/100, Loss: 0.0477\n", - "Epoch 52/100, Loss: 0.0539\n", - "Epoch 53/100, Loss: 0.0478\n", - "Epoch 54/100, Loss: 0.0539\n", - "Epoch 55/100, Loss: 0.0435\n", - "Epoch 56/100, Loss: 0.0471\n", - "Epoch 57/100, Loss: 0.0461\n", - "Epoch 58/100, Loss: 0.0452\n", - "Epoch 59/100, Loss: 0.0507\n", - "Epoch 60/100, Loss: 0.0481\n", - "Epoch 61/100, Loss: 0.0398\n", - "Epoch 62/100, Loss: 0.0535\n", - "Epoch 63/100, Loss: 0.0503\n", - "Epoch 64/100, Loss: 0.0504\n", - "Epoch 65/100, Loss: 0.0473\n", - "Epoch 66/100, Loss: 0.0553\n", - "Epoch 67/100, Loss: 0.0514\n", - "Epoch 68/100, Loss: 0.0450\n", - "Epoch 69/100, Loss: 0.0488\n", - "Epoch 70/100, Loss: 0.0414\n", - "Epoch 71/100, Loss: 0.0413\n", - "Epoch 72/100, Loss: 0.0473\n", - "Epoch 73/100, Loss: 0.0530\n", - "Epoch 74/100, Loss: 0.0482\n", - "Epoch 75/100, Loss: 0.0477\n", - "Epoch 76/100, Loss: 0.0483\n", - "Epoch 77/100, Loss: 0.0452\n", - "Epoch 78/100, Loss: 0.0452\n", - "Epoch 79/100, Loss: 0.0474\n", - "Epoch 80/100, Loss: 0.0483\n", - "Epoch 81/100, Loss: 0.0522\n", - "Epoch 82/100, Loss: 0.0453\n", - "Epoch 83/100, Loss: 0.0436\n", - "Epoch 84/100, Loss: 0.0452\n", - "Epoch 85/100, Loss: 0.0523\n", - "Epoch 86/100, Loss: 0.0446\n", - "Epoch 87/100, Loss: 0.0475\n", - "Epoch 88/100, Loss: 0.0503\n", - "Epoch 89/100, Loss: 0.0484\n", - "Epoch 90/100, Loss: 0.0456\n", - "Epoch 91/100, Loss: 0.0433\n", - "Epoch 92/100, Loss: 0.0458\n", - "Epoch 93/100, Loss: 0.0461\n", - "Epoch 94/100, Loss: 0.0448\n", - "Epoch 95/100, Loss: 0.0432\n", - "Epoch 96/100, Loss: 0.0456\n", - "Epoch 97/100, Loss: 0.0470\n", - "Epoch 98/100, Loss: 0.0470\n", - "Epoch 99/100, Loss: 0.0467\n", - "Epoch 100/100, Loss: 0.0471\n" + "Epoch 38/100, Loss: 0.0580\n", + "Epoch 39/100, Loss: 0.0445\n", + "Epoch 40/100, Loss: 0.0516\n", + "Epoch 41/100, Loss: 0.0518\n", + "Epoch 42/100, Loss: 0.0545\n", + "Epoch 43/100, Loss: 0.0466\n", + "Epoch 44/100, Loss: 0.0523\n", + "Epoch 45/100, Loss: 0.0523\n", + "Epoch 46/100, Loss: 0.0547\n", + "Epoch 47/100, Loss: 0.0497\n", + "Epoch 48/100, Loss: 0.0512\n", + "Epoch 49/100, Loss: 0.0481\n", + "Epoch 50/100, Loss: 0.0498\n", + "Epoch 51/100, Loss: 0.0672\n", + "Epoch 52/100, Loss: 0.0530\n", + "Epoch 53/100, Loss: 0.0562\n", + "Epoch 54/100, Loss: 0.0536\n", + "Epoch 55/100, Loss: 0.0482\n", + "Epoch 56/100, Loss: 0.0438\n", + "Epoch 57/100, Loss: 0.0467\n", + "Epoch 58/100, Loss: 0.0501\n", + "Epoch 59/100, Loss: 0.0445\n", + "Epoch 60/100, Loss: 0.0471\n", + "Epoch 61/100, Loss: 0.0502\n", + "Epoch 62/100, Loss: 0.0474\n", + "Epoch 63/100, Loss: 0.0420\n", + "Epoch 64/100, Loss: 0.0541\n", + "Epoch 65/100, Loss: 0.0491\n", + "Epoch 66/100, Loss: 0.0489\n", + "Epoch 67/100, Loss: 0.0498\n", + "Epoch 68/100, Loss: 0.0511\n", + "Epoch 69/100, Loss: 0.0463\n", + "Epoch 70/100, Loss: 0.0480\n", + "Epoch 71/100, Loss: 0.0460\n", + "Epoch 72/100, Loss: 0.0533\n", + "Epoch 73/100, Loss: 0.0515\n", + "Epoch 74/100, Loss: 0.0419\n", + "Epoch 75/100, Loss: 0.0491\n", + "Epoch 76/100, Loss: 0.0471\n", + "Epoch 77/100, Loss: 0.0479\n", + "Epoch 78/100, Loss: 0.0444\n", + "Epoch 79/100, Loss: 0.0520\n", + "Epoch 80/100, Loss: 0.0520\n", + "Epoch 81/100, Loss: 0.0489\n", + "Epoch 82/100, Loss: 0.0467\n", + "Epoch 83/100, Loss: 0.0464\n", + "Epoch 84/100, Loss: 0.0451\n", + "Epoch 85/100, Loss: 0.0526\n", + 
"Epoch 86/100, Loss: 0.0501\n", + "Epoch 87/100, Loss: 0.0438\n", + "Epoch 88/100, Loss: 0.0476\n", + "Epoch 89/100, Loss: 0.0442\n", + "Epoch 90/100, Loss: 0.0432\n", + "Epoch 91/100, Loss: 0.0469\n", + "Epoch 92/100, Loss: 0.0494\n", + "Epoch 93/100, Loss: 0.0487\n", + "Epoch 94/100, Loss: 0.0445\n", + "Epoch 95/100, Loss: 0.0442\n", + "Epoch 96/100, Loss: 0.0417\n", + "Epoch 97/100, Loss: 0.0441\n", + "Epoch 98/100, Loss: 0.0417\n", + "Epoch 99/100, Loss: 0.0435\n", + "Epoch 100/100, Loss: 0.0433\n" ] }, { @@ -1425,7 +1434,7 @@ ")" ] }, - "execution_count": 11, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1486,7 +1495,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 30, "id": "8efb1396", "metadata": {}, "outputs": [], @@ -1519,7 +1528,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 31, "id": "b4a1c9d9", "metadata": {}, "outputs": [ @@ -1527,16 +1536,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "Fine-tune Epoch 1/10, Loss: 4.8250\n", - "Fine-tune Epoch 2/10, Loss: 2.6400\n", - "Fine-tune Epoch 3/10, Loss: 1.6920\n", - "Fine-tune Epoch 4/10, Loss: 1.2630\n", - "Fine-tune Epoch 5/10, Loss: 1.0183\n", - "Fine-tune Epoch 6/10, Loss: 0.8237\n", - "Fine-tune Epoch 7/10, Loss: 0.6869\n", - "Fine-tune Epoch 8/10, Loss: 0.5854\n", - "Fine-tune Epoch 9/10, Loss: 0.5155\n", - "Fine-tune Epoch 10/10, Loss: 0.4489\n" + "Fine-tune Epoch 1/10, Loss: 4.7966\n", + "Fine-tune Epoch 2/10, Loss: 2.6961\n", + "Fine-tune Epoch 3/10, Loss: 1.7293\n", + "Fine-tune Epoch 4/10, Loss: 1.2899\n", + "Fine-tune Epoch 5/10, Loss: 1.0189\n", + "Fine-tune Epoch 6/10, Loss: 0.8710\n", + "Fine-tune Epoch 7/10, Loss: 0.7198\n", + "Fine-tune Epoch 8/10, Loss: 0.6079\n", + "Fine-tune Epoch 9/10, Loss: 0.5297\n", + "Fine-tune Epoch 10/10, Loss: 0.4712\n" ] } ], @@ -1565,7 +1574,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 32, "id": "d77b8ff5", "metadata": {}, "outputs": [], @@ -1580,7 +1589,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 33, "id": "79fea720", "metadata": {}, "outputs": [ @@ -1588,7 +1597,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Deep learning ena les selelf les te\n" + "Deep learning ena les te autt es re\n" ] } ],