mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
docs, logic: обновление документации и автодовосстановления обучения модели, актуализация index.md
This commit is contained in:
@@ -7,14 +7,17 @@ Callback-система для управления обучением GPT.
|
||||
- LRSchedulerCallback - регулировка learning rate
|
||||
"""
|
||||
|
||||
# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/__init__.py
|
||||
from .callback import Callback
|
||||
from .early_stopping_callback import EarlyStoppingCallback
|
||||
from .lrs_scheduler_callback import LRSchedulerCallback
|
||||
from .model_checkpoint_callback import ModelCheckpointCallback
|
||||
from .resume_training_callback import ResumeTrainingCallback
|
||||
|
||||
__all__ = [
|
||||
'Callback',
|
||||
'EarlyStoppingCallback',
|
||||
'LRSchedulerCallback',
|
||||
'ModelCheckpointCallback'
|
||||
'LRSchedulerCallback',
|
||||
'ModelCheckpointCallback',
|
||||
'ResumeTrainingCallback'
|
||||
]
|
||||
@@ -12,6 +12,9 @@ class Callback:
|
||||
- on_batch_end - после обработки батча
|
||||
- on_epoch_end - в конце эпохи
|
||||
"""
|
||||
|
||||
def on_train_begin(self, model):
|
||||
pass
|
||||
|
||||
def on_epoch_begin(self, epoch, model):
|
||||
"""Вызывается перед началом эпохи.
|
||||
|
||||
@@ -14,8 +14,22 @@ class LRSchedulerCallback(Callback):
|
||||
def __init__(self, lr, decay=0.95):
|
||||
self.base_lr = lr
|
||||
self.decay = decay
|
||||
self.last_epoch = -1 # Добавляем отслеживание эпохи
|
||||
|
||||
def get_state(self):
|
||||
return {
|
||||
'base_lr': self.base_lr,
|
||||
'decay': self.decay,
|
||||
'last_epoch': self.last_epoch
|
||||
}
|
||||
|
||||
def set_state(self, state):
|
||||
self.base_lr = state['base_lr']
|
||||
self.decay = state['decay']
|
||||
self.last_epoch = state['last_epoch']
|
||||
|
||||
def on_epoch_begin(self, epoch, model):
|
||||
self.last_epoch = epoch # Сохраняем текущую эпоху
|
||||
new_lr = self.base_lr * (self.decay ** epoch)
|
||||
for param_group in model.optimizer.param_groups:
|
||||
param_group['lr'] = new_lr
|
||||
@@ -1,6 +1,7 @@
|
||||
from .callback import Callback
|
||||
import torch
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
class ModelCheckpointCallback(Callback):
|
||||
"""Сохраняет чекпоинты модели во время обучения.
|
||||
@@ -11,24 +12,64 @@ class ModelCheckpointCallback(Callback):
|
||||
|
||||
Args:
|
||||
save_dir (str): Директория для сохранения
|
||||
save_best_only (bool): Сохранять только лучшие модели
|
||||
save_best_only (bool): Если True, сохраняет только при улучшении loss
|
||||
save_freq (int): Сохранять каждые N эпох (default=1)
|
||||
monitor (str): Какой loss мониторить ('val' или 'train')
|
||||
"""
|
||||
def __init__(self, save_dir, save_best_only=True):
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
save_best_only: bool = True,
|
||||
save_freq: int = 1,
|
||||
monitor: str = 'val'):
|
||||
self.save_dir = save_dir
|
||||
self.save_best_only = save_best_only
|
||||
self.save_freq = save_freq
|
||||
self.monitor = monitor
|
||||
self.best_loss = float('inf')
|
||||
|
||||
def on_epoch_end(self, epoch, model, train_loss, val_loss):
|
||||
if not os.path.exists(self.save_dir):
|
||||
os.makedirs(self.save_dir)
|
||||
|
||||
current_loss = val_loss if val_loss else train_loss
|
||||
# Создаем директорию если её нет
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if not self.save_best_only or current_loss < self.best_loss:
|
||||
def on_epoch_end(self, epoch, model, train_loss, val_loss):
|
||||
# Решаем какой loss использовать для сравнения
|
||||
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
|
||||
|
||||
# Сохраняем по расписанию или при улучшении
|
||||
should_save = (
|
||||
(epoch + 1) % self.save_freq == 0 or # по расписанию
|
||||
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
|
||||
)
|
||||
|
||||
if should_save:
|
||||
self.best_loss = current_loss
|
||||
path = os.path.join(self.save_dir, f"checkpoint_epoch_{epoch}.pt")
|
||||
checkpoint_path = os.path.join(
|
||||
self.save_dir,
|
||||
f"checkpoint_epoch_{epoch}.pt"
|
||||
)
|
||||
|
||||
# Собираем состояния всех callback'ов
|
||||
callback_states = {}
|
||||
if hasattr(model, '_callbacks'):
|
||||
for cb in model._callbacks:
|
||||
if hasattr(cb, 'get_state'):
|
||||
callback_states[cb.__class__.__name__] = cb.get_state()
|
||||
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'model_state': model.state_dict(),
|
||||
'loss': current_loss
|
||||
}, path)
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': model.optimizer.state_dict(),
|
||||
'train_loss': train_loss,
|
||||
'val_loss': val_loss,
|
||||
'best_loss': self.best_loss,
|
||||
'callback_states': callback_states,
|
||||
'config': {
|
||||
'vocab_size': model._vocab_size,
|
||||
'max_seq_len': model._max_seq_len,
|
||||
'emb_size': model._emb_size,
|
||||
'num_heads': model._num_heads,
|
||||
'head_size': model._head_size,
|
||||
'num_layers': model._num_layers
|
||||
}
|
||||
}, checkpoint_path)
|
||||
|
||||
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")
|
||||
|
||||
53
simple_llm/transformer/callback/resume_training_callback.py
Normal file
53
simple_llm/transformer/callback/resume_training_callback.py
Normal file
@@ -0,0 +1,53 @@
|
||||
# /Users/sergey/Projects/ML/simple-llm/simple_llm/transformer/callback/resume_training_callback.py
|
||||
import os
|
||||
import torch
|
||||
from typing import Optional
|
||||
from .callback import Callback
|
||||
|
||||
class ResumeTrainingCallback(Callback):
|
||||
"""Callback для восстановления обучения с последнего чекпоинта"""
|
||||
|
||||
def __init__(self, checkpoint_dir: str, resume: bool = True):
|
||||
"""
|
||||
Args:
|
||||
checkpoint_dir: Путь к директории с чекпоинтами
|
||||
resume: Флаг восстановления обучения (default=True)
|
||||
"""
|
||||
self.checkpoint_dir = checkpoint_dir
|
||||
self.resume = resume
|
||||
self.last_epoch = -1
|
||||
|
||||
def on_train_begin(self, model):
|
||||
if not self.resume:
|
||||
return
|
||||
|
||||
checkpoint_path = self._find_latest_checkpoint()
|
||||
if checkpoint_path:
|
||||
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location=model._device)
|
||||
|
||||
# Убедимся, что загружаем на правильное устройство
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
if 'optimizer_state_dict' in checkpoint:
|
||||
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
|
||||
if hasattr(model, 'scheduler'):
|
||||
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
self.last_epoch = checkpoint.get('epoch', -1)
|
||||
|
||||
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
|
||||
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
|
||||
|
||||
def _find_latest_checkpoint(self) -> Optional[str]:
|
||||
if not os.path.exists(self.checkpoint_dir):
|
||||
return None
|
||||
|
||||
checkpoints = [f for f in os.listdir(self.checkpoint_dir)
|
||||
if f.startswith('checkpoint_') and f.endswith('.pt')]
|
||||
|
||||
if not checkpoints:
|
||||
return None
|
||||
|
||||
# Сортируем по времени создания
|
||||
checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
|
||||
return os.path.join(self.checkpoint_dir, checkpoints[-1])
|
||||
324
simple_llm_demo.ipynb
Normal file
324
simple_llm_demo.ipynb
Normal file
@@ -0,0 +1,324 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Демонстрация simple_llm\n",
|
||||
"## Полное руководство по установке и использованию"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 1. Установка и настройка\n",
|
||||
"\n",
|
||||
"### Клонирование репозитория:\n",
|
||||
"```bash\n",
|
||||
"git clone https://github.com/ваш_username/simple-llm.git\n",
|
||||
"cd simple-llm\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Установка зависимостей:\n",
|
||||
"```bash\n",
|
||||
"pip install -e .\n",
|
||||
"pip install torch tqdm\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Проверка структуры данных:\n",
|
||||
"```bash\n",
|
||||
"mkdir -p data/corpus/sample data/model data/tokenizer\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 2. Инициализация и проверка окружения"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import sys\n",
|
||||
"import os\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"# Проверка версии PyTorch\n",
|
||||
"print(f\"PyTorch version: {torch.__version__}\")\n",
|
||||
"\n",
|
||||
"# Добавление пути к библиотеке\n",
|
||||
"project_path = os.path.abspath('../simple-llm')\n",
|
||||
"sys.path.append(project_path)\n",
|
||||
"print(f\"Путь к проекту: {project_path}\")\n",
|
||||
"\n",
|
||||
"# Проверка модулей\n",
|
||||
"try:\n",
|
||||
" from simple_llm.tokenizer.bpe import BPETokenizer\n",
|
||||
" from simple_llm.data.get_data import load_text_corpus\n",
|
||||
" from simple_llm.transformer.gpt import GPT\n",
|
||||
" print(\"✓ Все модули успешно импортированы\")\n",
|
||||
"except ImportError as e:\n",
|
||||
" print(f\"✗ Ошибка: {e}\")\n",
|
||||
" print(\"Решение: выполните 'pip install -e .' из корня проекта\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3. Работа с токенизатором"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Инициализация токенизатора\n",
|
||||
"tokenizer = BPETokenizer()\n",
|
||||
"\n",
|
||||
"# Загрузка и обработка текста\n",
|
||||
"corpus_path = 'data/corpus/sample/'\n",
|
||||
"if os.path.exists(corpus_path):\n",
|
||||
" text = load_text_corpus(corpus_path)\n",
|
||||
" print(f\"Загружено текста: {len(text.split())} слов\")\n",
|
||||
" \n",
|
||||
" # Обучение токенизатора\n",
|
||||
" tokenizer.train(text, vocab_size=1000)\n",
|
||||
" print(f\"Токенизатор обучен, размер словаря: {tokenizer.vocab_size}\")\n",
|
||||
" \n",
|
||||
" # Тест токенизации\n",
|
||||
" test_phrase = \"Пример работы токенизатора\"\n",
|
||||
" tokens = tokenizer.encode(test_phrase)\n",
|
||||
" print(f\"Текст: {test_phrase}\")\n",
|
||||
" print(f\"Токены: {tokens}\")\n",
|
||||
" print(f\"Обратное преобразование: {tokenizer.decode(tokens)}\")\n",
|
||||
"else:\n",
|
||||
" print(f\"Директория {corpus_path} не содержит данных для обучения\")\n",
|
||||
" print(\"Добавьте текстовые файлы в формате .txt в эту директорию\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 4. Подготовка данных для обучения"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if 'text' in locals():\n",
|
||||
" # Токенизация всего корпуса\n",
|
||||
" all_tokens = tokenizer.encode(text)\n",
|
||||
" \n",
|
||||
" # Создание обучающих последовательностей\n",
|
||||
" seq_length = 64\n",
|
||||
" examples = []\n",
|
||||
" for i in range(0, len(all_tokens) - seq_length - 1, seq_length):\n",
|
||||
" input_seq = all_tokens[i:i+seq_length]\n",
|
||||
" target_seq = all_tokens[i+1:i+seq_length+1]\n",
|
||||
" examples.append((input_seq, target_seq))\n",
|
||||
" \n",
|
||||
" print(f\"Создано обучающих примеров: {len(examples)}\")\n",
|
||||
" print(f\"Размер последовательности: {seq_length} токенов\")\n",
|
||||
" print(f\"Пример входных данных: {examples[0][0][:10]}...\")\n",
|
||||
" print(f\"Пример целевых данных: {examples[0][1][:10]}...\")\n",
|
||||
"else:\n",
|
||||
" print(\"Невозможно подготовить данные: текст не загружен\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 5. Обучение модели GPT"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if 'examples' in locals() and len(examples) > 0:\n",
|
||||
" # Конфигурация модели\n",
|
||||
" config = {\n",
|
||||
" 'vocab_size': tokenizer.vocab_size,\n",
|
||||
" 'embed_dim': 128,\n",
|
||||
" 'num_heads': 4,\n",
|
||||
" 'num_layers': 3,\n",
|
||||
" 'max_len': seq_length\n",
|
||||
" }\n",
|
||||
" \n",
|
||||
" # Инициализация модели\n",
|
||||
" model = GPT(config)\n",
|
||||
" optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
|
||||
" criterion = torch.nn.CrossEntropyLoss()\n",
|
||||
" \n",
|
||||
" # Процесс обучения\n",
|
||||
" num_epochs = 5\n",
|
||||
" batch_size = 8\n",
|
||||
" \n",
|
||||
" for epoch in range(num_epochs):\n",
|
||||
" total_loss = 0\n",
|
||||
" model.train()\n",
|
||||
" \n",
|
||||
" for i in range(0, len(examples), batch_size):\n",
|
||||
" batch = examples[i:i+batch_size]\n",
|
||||
" inputs = torch.tensor([ex[0] for ex in batch])\n",
|
||||
" targets = torch.tensor([ex[1] for ex in batch])\n",
|
||||
" \n",
|
||||
" optimizer.zero_grad()\n",
|
||||
" outputs = model(inputs)\n",
|
||||
" loss = criterion(outputs.view(-1, config['vocab_size']), targets.view(-1))\n",
|
||||
" loss.backward()\n",
|
||||
" optimizer.step()\n",
|
||||
" \n",
|
||||
" total_loss += loss.item()\n",
|
||||
" \n",
|
||||
" avg_loss = total_loss / (len(examples) / batch_size)\n",
|
||||
" print(f\"Эпоха {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}\")\n",
|
||||
" \n",
|
||||
" print(\"Обучение завершено!\")\n",
|
||||
"else:\n",
|
||||
" print(\"Невозможно начать обучение: нет подготовленных данных\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 6. Генерация текста"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def generate_text(model, tokenizer, prompt, max_len=50, temperature=0.7):\n",
|
||||
" model.eval()\n",
|
||||
" tokens = tokenizer.encode(prompt)\n",
|
||||
" \n",
|
||||
" for _ in range(max_len):\n",
|
||||
" input_ids = torch.tensor([tokens[-config['max_len']:]])\n",
|
||||
" with torch.no_grad():\n",
|
||||
" logits = model(input_ids)[0, -1, :] / temperature\n",
|
||||
" probs = torch.softmax(logits, dim=-1)\n",
|
||||
" next_token = torch.multinomial(probs, num_samples=1).item()\n",
|
||||
" tokens.append(next_token)\n",
|
||||
" \n",
|
||||
" return tokenizer.decode(tokens)\n",
|
||||
"\n",
|
||||
"if 'model' in locals():\n",
|
||||
" prompts = [\n",
|
||||
" \"Сегодня прекрасный день,\",\n",
|
||||
" \"Искусственный интеллект\",\n",
|
||||
" \"В далеком будущем\"\n",
|
||||
" ]\n",
|
||||
" \n",
|
||||
" for prompt in prompts:\n",
|
||||
" generated = generate_text(model, tokenizer, prompt)\n",
|
||||
" print(f\"Промпт: '{prompt}'\")\n",
|
||||
" print(f\"Результат: {generated}\\n\")\n",
|
||||
"else:\n",
|
||||
" print(\"Модель не обучена, генерация невозможна\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 7. Сохранение и загрузка моделей"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def save_model(model, tokenizer, model_name):\n",
|
||||
" model_path = f\"data/model/{model_name}.pth\"\n",
|
||||
" tokenizer_path = f\"data/tokenizer/{model_name}_tokenizer.json\"\n",
|
||||
" \n",
|
||||
" torch.save(model.state_dict(), model_path)\n",
|
||||
" tokenizer.save(tokenizer_path)\n",
|
||||
" print(f\"Модель сохранена в {model_path}\")\n",
|
||||
" print(f\"Токенизатор сохранен в {tokenizer_path}\")\n",
|
||||
"\n",
|
||||
"def load_model(model_name, config):\n",
|
||||
" model_path = f\"data/model/{model_name}.pth\"\n",
|
||||
" tokenizer_path = f\"data/tokenizer/{model_name}_tokenizer.json\"\n",
|
||||
" \n",
|
||||
" model = GPT(config)\n",
|
||||
" model.load_state_dict(torch.load(model_path))\n",
|
||||
" \n",
|
||||
" tokenizer = BPETokenizer()\n",
|
||||
" tokenizer.load(tokenizer_path)\n",
|
||||
" \n",
|
||||
" print(f\"Модель загружена из {model_path}\")\n",
|
||||
" return model, tokenizer\n",
|
||||
"\n",
|
||||
"# Пример использования:\n",
|
||||
"# save_model(model, tokenizer, \"my_first_model\")\n",
|
||||
"# loaded_model, loaded_tokenizer = load_model(\"my_first_model\", config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 8. Советы по улучшению\n",
|
||||
"\n",
|
||||
"1. Для лучших результатов:\n",
|
||||
" - Увеличьте размер корпуса\n",
|
||||
" - Добавьте больше эпох обучения\n",
|
||||
" - Настройте параметры модели\n",
|
||||
"\n",
|
||||
"2. Экспериментируйте с:\n",
|
||||
" - Температурой генерации (0.1-1.0)\n",
|
||||
" - Разными промптами\n",
|
||||
" - Архитектурой модели\n",
|
||||
"\n",
|
||||
"3. Дополнительные возможности:\n",
|
||||
" - Визуализация attention-карт\n",
|
||||
" - Реализация beam search\n",
|
||||
" - Fine-tuning на специфичных данных"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
202
tests/test_resume_training.py
Normal file
202
tests/test_resume_training.py
Normal file
@@ -0,0 +1,202 @@
|
||||
import os
|
||||
import tempfile
|
||||
import torch
|
||||
import pytest
|
||||
from simple_llm.transformer.gpt import GPT
|
||||
from simple_llm.transformer.callback import ModelCheckpointCallback, ResumeTrainingCallback
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
from simple_llm.transformer.callback import (
|
||||
LRSchedulerCallback,
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data():
|
||||
# Создаем тестовые данные
|
||||
inputs = torch.randint(0, 100, (100, 10)) # 100 samples, seq_len=10
|
||||
targets = torch.randint(0, 100, (100, 10))
|
||||
return DataLoader(TensorDataset(inputs, targets), batch_size=10)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_model():
|
||||
return GPT(
|
||||
vocab_size=100,
|
||||
max_seq_len=10,
|
||||
emb_size=32,
|
||||
num_heads=4,
|
||||
head_size=8,
|
||||
num_layers=2
|
||||
)
|
||||
|
||||
def test_model_checkpoint_saving(sample_model, sample_data):
|
||||
"""Тестирует корректность сохранения чекпоинтов"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
|
||||
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
|
||||
|
||||
files = os.listdir(tmpdir)
|
||||
assert len(files) == 1
|
||||
assert files[0].startswith('checkpoint_epoch_')
|
||||
|
||||
checkpoint = torch.load(os.path.join(tmpdir, files[0]))
|
||||
assert 'model_state_dict' in checkpoint
|
||||
assert 'optimizer_state_dict' in checkpoint
|
||||
assert 'epoch' in checkpoint
|
||||
assert 'train_loss' in checkpoint
|
||||
|
||||
def test_resume_training(sample_model, sample_data):
|
||||
"""Тестирует восстановление обучения из чекпоинта"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
|
||||
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
|
||||
|
||||
new_model = GPT(
|
||||
vocab_size=100,
|
||||
max_seq_len=10,
|
||||
emb_size=32,
|
||||
num_heads=4,
|
||||
head_size=8,
|
||||
num_layers=2
|
||||
)
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
new_model.fit(
|
||||
sample_data,
|
||||
num_epoch=2,
|
||||
callbacks=[resume_cb, checkpoint_cb],
|
||||
resume_training=True
|
||||
)
|
||||
|
||||
assert resume_cb.last_epoch == 0
|
||||
files = os.listdir(tmpdir)
|
||||
assert len(files) == 2
|
||||
|
||||
def test_resume_with_missing_checkpoint(sample_model, sample_data):
|
||||
"""Тестирует поведение при отсутствии чекпоинтов"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
assert len(os.listdir(tmpdir)) == 0
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=1,
|
||||
callbacks=[resume_cb],
|
||||
resume_training=True
|
||||
)
|
||||
|
||||
assert resume_cb.last_epoch == -1
|
||||
|
||||
def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
|
||||
"""Тестирует обработку битых чекпоинтов"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
bad_checkpoint = os.path.join(tmpdir, "checkpoint_epoch_0.pt")
|
||||
with open(bad_checkpoint, 'w') as f:
|
||||
f.write("corrupted data")
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=1,
|
||||
callbacks=[resume_cb],
|
||||
resume_training=True
|
||||
)
|
||||
|
||||
def test_optimizer_state_restoration(sample_model, sample_data):
|
||||
"""Тестирует восстановление состояния оптимизатора"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint_cb = ModelCheckpointCallback(tmpdir)
|
||||
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
|
||||
|
||||
original_optimizer_state = sample_model.optimizer.state_dict()
|
||||
|
||||
new_model = GPT(
|
||||
vocab_size=100,
|
||||
max_seq_len=10,
|
||||
emb_size=32,
|
||||
num_heads=4,
|
||||
head_size=8,
|
||||
num_layers=2
|
||||
)
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
new_model.fit(
|
||||
sample_data,
|
||||
num_epoch=2,
|
||||
callbacks=[resume_cb, checkpoint_cb],
|
||||
resume_training=True
|
||||
)
|
||||
|
||||
assert 'state' in new_model.optimizer.state_dict()
|
||||
assert 'param_groups' in new_model.optimizer.state_dict()
|
||||
|
||||
# Проверяем только параметры, кроме lr (так как он меняется scheduler'ом)
|
||||
for key in original_optimizer_state['param_groups'][0]:
|
||||
if key not in ['params', 'lr']:
|
||||
assert (
|
||||
original_optimizer_state['param_groups'][0][key] ==
|
||||
new_model.optimizer.state_dict()['param_groups'][0][key]
|
||||
)
|
||||
|
||||
def test_multiple_resumes(sample_model, sample_data):
|
||||
"""Тестирует многократное восстановление обучения"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint_cb = ModelCheckpointCallback(tmpdir)
|
||||
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=2,
|
||||
callbacks=[resume_cb, checkpoint_cb],
|
||||
resume_training=True
|
||||
)
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=3,
|
||||
callbacks=[resume_cb, checkpoint_cb],
|
||||
resume_training=True
|
||||
)
|
||||
|
||||
files = os.listdir(tmpdir)
|
||||
assert len(files) == 3
|
||||
|
||||
def test_scheduler_state_restoration(sample_model, sample_data):
|
||||
"""Тестирует восстановление состояния LR"""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint_cb = ModelCheckpointCallback(tmpdir)
|
||||
lr_scheduler_cb = LRSchedulerCallback(lr=0.001)
|
||||
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=1,
|
||||
callbacks=[checkpoint_cb, lr_scheduler_cb],
|
||||
learning_rate=0.001
|
||||
)
|
||||
|
||||
# Сохраняем текущий lr с учетом decay
|
||||
expected_lr = 0.001 * (0.95 ** 1) # decay^epoch
|
||||
|
||||
new_model = GPT(
|
||||
vocab_size=100,
|
||||
max_seq_len=10,
|
||||
emb_size=32,
|
||||
num_heads=4,
|
||||
head_size=8,
|
||||
num_layers=2
|
||||
)
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
new_model.fit(
|
||||
sample_data,
|
||||
num_epoch=2,
|
||||
callbacks=[resume_cb, checkpoint_cb, lr_scheduler_cb],
|
||||
resume_training=True,
|
||||
learning_rate=0.001
|
||||
)
|
||||
|
||||
# Проверяем что LR восстановлен с учетом decay
|
||||
assert new_model.optimizer.param_groups[0]['lr'] == pytest.approx(expected_lr)
|
||||
Reference in New Issue
Block a user