mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
fix(rope, attention): корректное позиционирование RoPE при генерации с кэшем
- Исправлена ошибка расчёта позиции для RoPE (Rotary Positional Embeddings) при автодополнении с использованием кэша. - В HeadAttention теперь передаётся start_pos в RoPE, вычисляемый из длины кэша. - Обновлена сигнатура и логика метода RoPE.forward. - Обновлен ноутбук llama.ipynb под новые интерфейсы и выводы. BREAKING CHANGE: переопределён метод forward у RoPE, требуется обновить код, если RoPE использовался вручную.
This commit is contained in:
@@ -89,10 +89,16 @@ class HeadAttention(nn.Module):
|
|||||||
q = self._q(x) # [B, T, hs]
|
q = self._q(x) # [B, T, hs]
|
||||||
v = self._v(x) # [B, T, hs]
|
v = self._v(x) # [B, T, hs]
|
||||||
|
|
||||||
|
start_pos = 0
|
||||||
|
if cache is not None:
|
||||||
|
k_cache, v_cache = cache
|
||||||
|
cache_len = k_cache.shape[1]
|
||||||
|
start_pos = cache_len
|
||||||
|
|
||||||
if self._rope is not None:
|
if self._rope is not None:
|
||||||
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
# ✅ Применяем RoPE к Q и K (НЕ к V!)
|
||||||
q = self._rope(q) # [B, T, hs]
|
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
|
||||||
k = self._rope(k) # [B, T, hs]
|
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
|
||||||
|
|
||||||
if cache is not None:
|
if cache is not None:
|
||||||
k_cache, v_cache = cache
|
k_cache, v_cache = cache
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class RoPE(nn.Module):
|
|||||||
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
|
||||||
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Применение ротационного позиционного кодирования к входному тензору.
|
Применение ротационного позиционного кодирования к входному тензору.
|
||||||
|
|
||||||
@@ -83,11 +83,11 @@ class RoPE(nn.Module):
|
|||||||
2. Применение вращения через синусы и косинусы
|
2. Применение вращения через синусы и косинусы
|
||||||
3. Объединение компонент обратно
|
3. Объединение компонент обратно
|
||||||
"""
|
"""
|
||||||
seq_len = x.size(1)
|
batch_size, seq_len, emb_size = x.shape
|
||||||
|
|
||||||
# Берем нужную часть матриц и приводим к типу x
|
# Берем нужную часть матриц и приводим к типу x
|
||||||
cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
|
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
|
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
|
||||||
|
|
||||||
# Разделяем на четные и нечетные компоненты
|
# Разделяем на четные и нечетные компоненты
|
||||||
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
|
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
|
||||||
|
|||||||
@@ -78,7 +78,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 19,
|
||||||
"id": "873704be",
|
"id": "873704be",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -149,7 +149,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 20,
|
||||||
"id": "0484cf77",
|
"id": "0484cf77",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -173,7 +173,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 21,
|
||||||
"id": "74ca39ba",
|
"id": "74ca39ba",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -432,7 +432,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 22,
|
||||||
"id": "02c300f9",
|
"id": "02c300f9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -480,7 +480,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 23,
|
||||||
"id": "ac072b9b",
|
"id": "ac072b9b",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -515,7 +515,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 24,
|
||||||
"id": "4fe79d02",
|
"id": "4fe79d02",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -543,11 +543,12 @@
|
|||||||
" self.register_buffer('sin_matrix', torch.sin(freq_matrix))\n",
|
" self.register_buffer('sin_matrix', torch.sin(freq_matrix))\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def forward(self, x: torch.Tensor): # Получает на вход тензор x (тип float) размером [batch_size × seq_len × head_size]\n",
|
" def forward(self, x: torch.Tensor, start_pos: int = 0): # Получает на вход тензор x (тип float) размером [batch_size × seq_len × head_size]\n",
|
||||||
" seq_len = x.size(1)\n",
|
" batch_size, seq_len, emb_size = x.shape\n",
|
||||||
|
"\n",
|
||||||
" # Берем нужную часть матриц и приводим к типу x\n",
|
" # Берем нужную часть матриц и приводим к типу x\n",
|
||||||
" cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
|
" cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
|
||||||
" sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
|
" sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n",
|
||||||
" \n",
|
" \n",
|
||||||
"\n",
|
"\n",
|
||||||
" # Разделяем на четные и нечетные\n",
|
" # Разделяем на четные и нечетные\n",
|
||||||
@@ -580,7 +581,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 25,
|
||||||
"id": "fe1274b1",
|
"id": "fe1274b1",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -630,7 +631,9 @@
|
|||||||
" self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())\n",
|
" self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())\n",
|
||||||
"\n",
|
"\n",
|
||||||
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n",
|
" def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:\n",
|
||||||
" seq_len = x.shape[1]\n",
|
" #seq_len = x.shape[1]\n",
|
||||||
|
" batch_size, seq_len, emb_size = x.shape\n",
|
||||||
|
"\n",
|
||||||
" if seq_len > self._max_seq_len:\n",
|
" if seq_len > self._max_seq_len:\n",
|
||||||
" raise ValueError(f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\")\n",
|
" raise ValueError(f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\")\n",
|
||||||
"\n",
|
"\n",
|
||||||
@@ -638,9 +641,15 @@
|
|||||||
" q = self._q(x) # [B, T, hs]\n",
|
" q = self._q(x) # [B, T, hs]\n",
|
||||||
" v = self._v(x) # [B, T, hs]\n",
|
" v = self._v(x) # [B, T, hs]\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
" start_pos = 0\n",
|
||||||
|
" if cache is not None:\n",
|
||||||
|
" k_cache, v_cache = cache\n",
|
||||||
|
" cache_len = k_cache.shape[1]\n",
|
||||||
|
" start_pos = cache_len\n",
|
||||||
|
"\n",
|
||||||
" # ✅ Применяем RoPE к Q и K (НЕ к V!)\n",
|
" # ✅ Применяем RoPE к Q и K (НЕ к V!)\n",
|
||||||
" q = self._rope(q) # [B, T, hs]\n",
|
" q = self._rope(q, start_pos=start_pos) # [B, T, hs]\n",
|
||||||
" k = self._rope(k) # [B, T, hs]\n",
|
" k = self._rope(k, start_pos=start_pos) # [B, T, hs]\n",
|
||||||
"\n",
|
"\n",
|
||||||
" if cache is not None:\n",
|
" if cache is not None:\n",
|
||||||
" k_cache, v_cache = cache\n",
|
" k_cache, v_cache = cache\n",
|
||||||
@@ -1031,7 +1040,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 26,
|
||||||
"id": "078ed4ce",
|
"id": "078ed4ce",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -1177,7 +1186,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 9,
|
"execution_count": 27,
|
||||||
"id": "0258c483",
|
"id": "0258c483",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -1212,7 +1221,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 10,
|
"execution_count": 28,
|
||||||
"id": "f8f06cf7",
|
"id": "f8f06cf7",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -1275,7 +1284,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 11,
|
"execution_count": 29,
|
||||||
"id": "0c753d52",
|
"id": "0c753d52",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -1284,106 +1293,106 @@
|
|||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Dataset length: 20\n",
|
"Dataset length: 20\n",
|
||||||
"Epoch 1/100, Loss: 3.9063\n",
|
"Epoch 1/100, Loss: 3.7312\n",
|
||||||
"Epoch 2/100, Loss: 1.6363\n",
|
"Epoch 2/100, Loss: 1.5022\n",
|
||||||
"Epoch 3/100, Loss: 0.6436\n",
|
"Epoch 3/100, Loss: 0.5940\n",
|
||||||
"Epoch 4/100, Loss: 0.3222\n",
|
"Epoch 4/100, Loss: 0.3368\n",
|
||||||
"Epoch 5/100, Loss: 0.2031\n",
|
"Epoch 5/100, Loss: 0.2121\n",
|
||||||
"Epoch 6/100, Loss: 0.1460\n",
|
"Epoch 6/100, Loss: 0.1507\n",
|
||||||
"Epoch 7/100, Loss: 0.1119\n",
|
"Epoch 7/100, Loss: 0.1095\n",
|
||||||
"Epoch 8/100, Loss: 0.0978\n",
|
"Epoch 8/100, Loss: 0.1023\n",
|
||||||
"Epoch 9/100, Loss: 0.0925\n",
|
"Epoch 9/100, Loss: 0.0831\n",
|
||||||
"Epoch 10/100, Loss: 0.0757\n",
|
"Epoch 10/100, Loss: 0.0805\n",
|
||||||
"Epoch 11/100, Loss: 0.0809\n",
|
"Epoch 11/100, Loss: 0.0789\n",
|
||||||
"Epoch 12/100, Loss: 0.0763\n",
|
"Epoch 12/100, Loss: 0.0802\n",
|
||||||
"Epoch 13/100, Loss: 0.0671\n",
|
"Epoch 13/100, Loss: 0.0758\n",
|
||||||
"Epoch 14/100, Loss: 0.0681\n",
|
"Epoch 14/100, Loss: 0.0676\n",
|
||||||
"Epoch 15/100, Loss: 0.0582\n",
|
"Epoch 15/100, Loss: 0.0655\n",
|
||||||
"Epoch 16/100, Loss: 0.0655\n",
|
"Epoch 16/100, Loss: 0.0701\n",
|
||||||
"Epoch 17/100, Loss: 0.0751\n",
|
"Epoch 17/100, Loss: 0.0631\n",
|
||||||
"Epoch 18/100, Loss: 0.0656\n",
|
"Epoch 18/100, Loss: 0.0691\n",
|
||||||
"Epoch 19/100, Loss: 0.0625\n",
|
"Epoch 19/100, Loss: 0.0687\n",
|
||||||
"Epoch 20/100, Loss: 0.0593\n",
|
"Epoch 20/100, Loss: 0.0623\n",
|
||||||
"Epoch 21/100, Loss: 0.0678\n",
|
"Epoch 21/100, Loss: 0.0574\n",
|
||||||
"Epoch 22/100, Loss: 0.0652\n",
|
"Epoch 22/100, Loss: 0.0603\n",
|
||||||
"Epoch 23/100, Loss: 0.0644\n",
|
"Epoch 23/100, Loss: 0.0592\n",
|
||||||
"Epoch 24/100, Loss: 0.0542\n",
|
"Epoch 24/100, Loss: 0.0587\n",
|
||||||
"Epoch 25/100, Loss: 0.0627\n",
|
"Epoch 25/100, Loss: 0.0598\n",
|
||||||
"Epoch 26/100, Loss: 0.0596\n",
|
"Epoch 26/100, Loss: 0.0589\n",
|
||||||
"Epoch 27/100, Loss: 0.0639\n",
|
"Epoch 27/100, Loss: 0.0589\n",
|
||||||
"Epoch 28/100, Loss: 0.0536\n",
|
"Epoch 28/100, Loss: 0.0546\n",
|
||||||
"Epoch 29/100, Loss: 0.0606\n",
|
"Epoch 29/100, Loss: 0.0570\n",
|
||||||
"Epoch 30/100, Loss: 0.0565\n",
|
"Epoch 30/100, Loss: 0.0673\n",
|
||||||
"Epoch 31/100, Loss: 0.0622\n",
|
"Epoch 31/100, Loss: 0.0601\n",
|
||||||
"Epoch 32/100, Loss: 0.0581\n",
|
"Epoch 32/100, Loss: 0.0702\n",
|
||||||
"Epoch 33/100, Loss: 0.0492\n",
|
"Epoch 33/100, Loss: 0.0528\n",
|
||||||
"Epoch 34/100, Loss: 0.0553\n",
|
"Epoch 34/100, Loss: 0.0508\n",
|
||||||
"Epoch 35/100, Loss: 0.0526\n",
|
"Epoch 35/100, Loss: 0.0522\n",
|
||||||
"Epoch 36/100, Loss: 0.0527\n",
|
"Epoch 36/100, Loss: 0.0537\n",
|
||||||
"Epoch 37/100, Loss: 0.0570\n",
|
"Epoch 37/100, Loss: 0.0570\n",
|
||||||
"Epoch 38/100, Loss: 0.0482\n",
|
"Epoch 38/100, Loss: 0.0580\n",
|
||||||
"Epoch 39/100, Loss: 0.0553\n",
|
"Epoch 39/100, Loss: 0.0445\n",
|
||||||
"Epoch 40/100, Loss: 0.0444\n",
|
"Epoch 40/100, Loss: 0.0516\n",
|
||||||
"Epoch 41/100, Loss: 0.0602\n",
|
"Epoch 41/100, Loss: 0.0518\n",
|
||||||
"Epoch 42/100, Loss: 0.0599\n",
|
"Epoch 42/100, Loss: 0.0545\n",
|
||||||
"Epoch 43/100, Loss: 0.0598\n",
|
"Epoch 43/100, Loss: 0.0466\n",
|
||||||
"Epoch 44/100, Loss: 0.0572\n",
|
"Epoch 44/100, Loss: 0.0523\n",
|
||||||
"Epoch 45/100, Loss: 0.0551\n",
|
"Epoch 45/100, Loss: 0.0523\n",
|
||||||
"Epoch 46/100, Loss: 0.0577\n",
|
"Epoch 46/100, Loss: 0.0547\n",
|
||||||
"Epoch 47/100, Loss: 0.0527\n",
|
"Epoch 47/100, Loss: 0.0497\n",
|
||||||
"Epoch 48/100, Loss: 0.0466\n",
|
"Epoch 48/100, Loss: 0.0512\n",
|
||||||
"Epoch 49/100, Loss: 0.0551\n",
|
"Epoch 49/100, Loss: 0.0481\n",
|
||||||
"Epoch 50/100, Loss: 0.0517\n",
|
"Epoch 50/100, Loss: 0.0498\n",
|
||||||
"Epoch 51/100, Loss: 0.0477\n",
|
"Epoch 51/100, Loss: 0.0672\n",
|
||||||
"Epoch 52/100, Loss: 0.0539\n",
|
"Epoch 52/100, Loss: 0.0530\n",
|
||||||
"Epoch 53/100, Loss: 0.0478\n",
|
"Epoch 53/100, Loss: 0.0562\n",
|
||||||
"Epoch 54/100, Loss: 0.0539\n",
|
"Epoch 54/100, Loss: 0.0536\n",
|
||||||
"Epoch 55/100, Loss: 0.0435\n",
|
"Epoch 55/100, Loss: 0.0482\n",
|
||||||
"Epoch 56/100, Loss: 0.0471\n",
|
"Epoch 56/100, Loss: 0.0438\n",
|
||||||
"Epoch 57/100, Loss: 0.0461\n",
|
"Epoch 57/100, Loss: 0.0467\n",
|
||||||
"Epoch 58/100, Loss: 0.0452\n",
|
"Epoch 58/100, Loss: 0.0501\n",
|
||||||
"Epoch 59/100, Loss: 0.0507\n",
|
"Epoch 59/100, Loss: 0.0445\n",
|
||||||
"Epoch 60/100, Loss: 0.0481\n",
|
"Epoch 60/100, Loss: 0.0471\n",
|
||||||
"Epoch 61/100, Loss: 0.0398\n",
|
"Epoch 61/100, Loss: 0.0502\n",
|
||||||
"Epoch 62/100, Loss: 0.0535\n",
|
"Epoch 62/100, Loss: 0.0474\n",
|
||||||
"Epoch 63/100, Loss: 0.0503\n",
|
"Epoch 63/100, Loss: 0.0420\n",
|
||||||
"Epoch 64/100, Loss: 0.0504\n",
|
"Epoch 64/100, Loss: 0.0541\n",
|
||||||
"Epoch 65/100, Loss: 0.0473\n",
|
"Epoch 65/100, Loss: 0.0491\n",
|
||||||
"Epoch 66/100, Loss: 0.0553\n",
|
"Epoch 66/100, Loss: 0.0489\n",
|
||||||
"Epoch 67/100, Loss: 0.0514\n",
|
"Epoch 67/100, Loss: 0.0498\n",
|
||||||
"Epoch 68/100, Loss: 0.0450\n",
|
"Epoch 68/100, Loss: 0.0511\n",
|
||||||
"Epoch 69/100, Loss: 0.0488\n",
|
"Epoch 69/100, Loss: 0.0463\n",
|
||||||
"Epoch 70/100, Loss: 0.0414\n",
|
"Epoch 70/100, Loss: 0.0480\n",
|
||||||
"Epoch 71/100, Loss: 0.0413\n",
|
"Epoch 71/100, Loss: 0.0460\n",
|
||||||
"Epoch 72/100, Loss: 0.0473\n",
|
"Epoch 72/100, Loss: 0.0533\n",
|
||||||
"Epoch 73/100, Loss: 0.0530\n",
|
"Epoch 73/100, Loss: 0.0515\n",
|
||||||
"Epoch 74/100, Loss: 0.0482\n",
|
"Epoch 74/100, Loss: 0.0419\n",
|
||||||
"Epoch 75/100, Loss: 0.0477\n",
|
"Epoch 75/100, Loss: 0.0491\n",
|
||||||
"Epoch 76/100, Loss: 0.0483\n",
|
"Epoch 76/100, Loss: 0.0471\n",
|
||||||
"Epoch 77/100, Loss: 0.0452\n",
|
"Epoch 77/100, Loss: 0.0479\n",
|
||||||
"Epoch 78/100, Loss: 0.0452\n",
|
"Epoch 78/100, Loss: 0.0444\n",
|
||||||
"Epoch 79/100, Loss: 0.0474\n",
|
"Epoch 79/100, Loss: 0.0520\n",
|
||||||
"Epoch 80/100, Loss: 0.0483\n",
|
"Epoch 80/100, Loss: 0.0520\n",
|
||||||
"Epoch 81/100, Loss: 0.0522\n",
|
"Epoch 81/100, Loss: 0.0489\n",
|
||||||
"Epoch 82/100, Loss: 0.0453\n",
|
"Epoch 82/100, Loss: 0.0467\n",
|
||||||
"Epoch 83/100, Loss: 0.0436\n",
|
"Epoch 83/100, Loss: 0.0464\n",
|
||||||
"Epoch 84/100, Loss: 0.0452\n",
|
"Epoch 84/100, Loss: 0.0451\n",
|
||||||
"Epoch 85/100, Loss: 0.0523\n",
|
"Epoch 85/100, Loss: 0.0526\n",
|
||||||
"Epoch 86/100, Loss: 0.0446\n",
|
"Epoch 86/100, Loss: 0.0501\n",
|
||||||
"Epoch 87/100, Loss: 0.0475\n",
|
"Epoch 87/100, Loss: 0.0438\n",
|
||||||
"Epoch 88/100, Loss: 0.0503\n",
|
"Epoch 88/100, Loss: 0.0476\n",
|
||||||
"Epoch 89/100, Loss: 0.0484\n",
|
"Epoch 89/100, Loss: 0.0442\n",
|
||||||
"Epoch 90/100, Loss: 0.0456\n",
|
"Epoch 90/100, Loss: 0.0432\n",
|
||||||
"Epoch 91/100, Loss: 0.0433\n",
|
"Epoch 91/100, Loss: 0.0469\n",
|
||||||
"Epoch 92/100, Loss: 0.0458\n",
|
"Epoch 92/100, Loss: 0.0494\n",
|
||||||
"Epoch 93/100, Loss: 0.0461\n",
|
"Epoch 93/100, Loss: 0.0487\n",
|
||||||
"Epoch 94/100, Loss: 0.0448\n",
|
"Epoch 94/100, Loss: 0.0445\n",
|
||||||
"Epoch 95/100, Loss: 0.0432\n",
|
"Epoch 95/100, Loss: 0.0442\n",
|
||||||
"Epoch 96/100, Loss: 0.0456\n",
|
"Epoch 96/100, Loss: 0.0417\n",
|
||||||
"Epoch 97/100, Loss: 0.0470\n",
|
"Epoch 97/100, Loss: 0.0441\n",
|
||||||
"Epoch 98/100, Loss: 0.0470\n",
|
"Epoch 98/100, Loss: 0.0417\n",
|
||||||
"Epoch 99/100, Loss: 0.0467\n",
|
"Epoch 99/100, Loss: 0.0435\n",
|
||||||
"Epoch 100/100, Loss: 0.0471\n"
|
"Epoch 100/100, Loss: 0.0433\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -1425,7 +1434,7 @@
|
|||||||
")"
|
")"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 11,
|
"execution_count": 29,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
@@ -1486,7 +1495,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 30,
|
||||||
"id": "8efb1396",
|
"id": "8efb1396",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -1519,7 +1528,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 31,
|
||||||
"id": "b4a1c9d9",
|
"id": "b4a1c9d9",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -1527,16 +1536,16 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Fine-tune Epoch 1/10, Loss: 4.8250\n",
|
"Fine-tune Epoch 1/10, Loss: 4.7966\n",
|
||||||
"Fine-tune Epoch 2/10, Loss: 2.6400\n",
|
"Fine-tune Epoch 2/10, Loss: 2.6961\n",
|
||||||
"Fine-tune Epoch 3/10, Loss: 1.6920\n",
|
"Fine-tune Epoch 3/10, Loss: 1.7293\n",
|
||||||
"Fine-tune Epoch 4/10, Loss: 1.2630\n",
|
"Fine-tune Epoch 4/10, Loss: 1.2899\n",
|
||||||
"Fine-tune Epoch 5/10, Loss: 1.0183\n",
|
"Fine-tune Epoch 5/10, Loss: 1.0189\n",
|
||||||
"Fine-tune Epoch 6/10, Loss: 0.8237\n",
|
"Fine-tune Epoch 6/10, Loss: 0.8710\n",
|
||||||
"Fine-tune Epoch 7/10, Loss: 0.6869\n",
|
"Fine-tune Epoch 7/10, Loss: 0.7198\n",
|
||||||
"Fine-tune Epoch 8/10, Loss: 0.5854\n",
|
"Fine-tune Epoch 8/10, Loss: 0.6079\n",
|
||||||
"Fine-tune Epoch 9/10, Loss: 0.5155\n",
|
"Fine-tune Epoch 9/10, Loss: 0.5297\n",
|
||||||
"Fine-tune Epoch 10/10, Loss: 0.4489\n"
|
"Fine-tune Epoch 10/10, Loss: 0.4712\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -1565,7 +1574,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 32,
|
||||||
"id": "d77b8ff5",
|
"id": "d77b8ff5",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
@@ -1580,7 +1589,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 33,
|
||||||
"id": "79fea720",
|
"id": "79fea720",
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
@@ -1588,7 +1597,7 @@
|
|||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"Deep learning ena les selelf les te\n"
|
"Deep learning ena les te autt es re\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user