diff --git a/simple_llm/transformer/callback/__init__.py b/simple_llm/transformer/callback/__init__.py
index c2fbc19..83a6dbd 100644
--- a/simple_llm/transformer/callback/__init__.py
+++ b/simple_llm/transformer/callback/__init__.py
@@ -7,14 +7,16 @@ Callback system for managing GPT training.
 - LRSchedulerCallback - learning rate adjustment
 """
 from .callback import Callback
 from .early_stopping_callback import EarlyStoppingCallback
 from .lrs_scheduler_callback import LRSchedulerCallback
 from .model_checkpoint_callback import ModelCheckpointCallback
+from .resume_training_callback import ResumeTrainingCallback
 
 __all__ = [
     'Callback',
     'EarlyStoppingCallback',
-    'LRSchedulerCallback',
-    'ModelCheckpointCallback'
+    'LRSchedulerCallback',
+    'ModelCheckpointCallback',
+    'ResumeTrainingCallback'
 ]
\ No newline at end of file
diff --git a/simple_llm/transformer/callback/callback.py b/simple_llm/transformer/callback/callback.py
index 3b506cb..34eceac 100644
--- a/simple_llm/transformer/callback/callback.py
+++ b/simple_llm/transformer/callback/callback.py
@@ -12,6 +12,10 @@ class Callback:
     - on_batch_end - after each batch is processed
     - on_epoch_end - at the end of an epoch
     """
+
+    def on_train_begin(self, model):
+        """Called once before training starts."""
+        pass
 
     def on_epoch_begin(self, epoch, model):
         """Called before the start of an epoch.
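With the new `on_train_begin` hook on the base class, a callback can run one-time setup before the first epoch. A minimal sketch of a custom callback using it (the `LoggingCallback` name and output format are illustrative, not part of the library):

```python
from simple_llm.transformer.callback import Callback

class LoggingCallback(Callback):
    """Hypothetical callback: reports the training lifecycle."""

    def on_train_begin(self, model):
        # Runs once, before the first epoch
        num_params = sum(p.numel() for p in model.parameters())
        print(f"Training started; model has {num_params} parameters")

    def on_epoch_begin(self, epoch, model):
        print(f"Epoch {epoch} starting")
```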
diff --git a/simple_llm/transformer/callback/lrs_scheduler_callback.py b/simple_llm/transformer/callback/lrs_scheduler_callback.py
index ac35f8d..cbcf77a 100644
--- a/simple_llm/transformer/callback/lrs_scheduler_callback.py
+++ b/simple_llm/transformer/callback/lrs_scheduler_callback.py
@@ -14,8 +14,22 @@ class LRSchedulerCallback(Callback):
     def __init__(self, lr, decay=0.95):
         self.base_lr = lr
         self.decay = decay
+        self.last_epoch = -1  # Track the current epoch for checkpointing
+
+    def get_state(self):
+        return {
+            'base_lr': self.base_lr,
+            'decay': self.decay,
+            'last_epoch': self.last_epoch
+        }
+
+    def set_state(self, state):
+        self.base_lr = state['base_lr']
+        self.decay = state['decay']
+        self.last_epoch = state['last_epoch']
 
     def on_epoch_begin(self, epoch, model):
+        self.last_epoch = epoch  # Remember the current epoch
         new_lr = self.base_lr * (self.decay ** epoch)
         for param_group in model.optimizer.param_groups:
             param_group['lr'] = new_lr
\ No newline at end of file
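The `get_state`/`set_state` pair is what `ModelCheckpointCallback` below serializes into each checkpoint. The round trip can be checked standalone, without a model:

```python
from simple_llm.transformer.callback import LRSchedulerCallback

cb = LRSchedulerCallback(lr=0.001, decay=0.9)
cb.last_epoch = 4  # pretend five epochs have already run

state = cb.get_state()  # {'base_lr': 0.001, 'decay': 0.9, 'last_epoch': 4}

restored = LRSchedulerCallback(lr=0.0)  # placeholder args, overwritten by set_state
restored.set_state(state)
assert restored.base_lr == 0.001 and restored.last_epoch == 4
```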
{checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=model._device) + + # Убедимся, что загружаем на правильное устройство + model.load_state_dict(checkpoint['model_state_dict']) + if 'optimizer_state_dict' in checkpoint: + model.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None: + if hasattr(model, 'scheduler'): + model.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + self.last_epoch = checkpoint.get('epoch', -1) + + print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}") + print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n") + + def _find_latest_checkpoint(self) -> Optional[str]: + if not os.path.exists(self.checkpoint_dir): + return None + + checkpoints = [f for f in os.listdir(self.checkpoint_dir) + if f.startswith('checkpoint_') and f.endswith('.pt')] + + if not checkpoints: + return None + + # Сортируем по времени создания + checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x))) + return os.path.join(self.checkpoint_dir, checkpoints[-1]) \ No newline at end of file diff --git a/simple_llm_demo.ipynb b/simple_llm_demo.ipynb new file mode 100644 index 0000000..c0336fc --- /dev/null +++ b/simple_llm_demo.ipynb @@ -0,0 +1,324 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Демонстрация simple_llm\n", + "## Полное руководство по установке и использованию" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Установка и настройка\n", + "\n", + "### Клонирование репозитория:\n", + "```bash\n", + "git clone https://github.com/ваш_username/simple-llm.git\n", + "cd simple-llm\n", + "```\n", + "\n", + "### Установка зависимостей:\n", + "```bash\n", + "pip install -e .\n", + "pip install torch tqdm\n", + "```\n", + "\n", + "### Проверка структуры данных:\n", + "```bash\n", + "mkdir -p data/corpus/sample data/model data/tokenizer\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Инициализация и проверка окружения" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "import os\n", + "import torch\n", + "\n", + "# Проверка версии PyTorch\n", + "print(f\"PyTorch version: {torch.__version__}\")\n", + "\n", + "# Добавление пути к библиотеке\n", + "project_path = os.path.abspath('../simple-llm')\n", + "sys.path.append(project_path)\n", + "print(f\"Путь к проекту: {project_path}\")\n", + "\n", + "# Проверка модулей\n", + "try:\n", + " from simple_llm.tokenizer.bpe import BPETokenizer\n", + " from simple_llm.data.get_data import load_text_corpus\n", + " from simple_llm.transformer.gpt import GPT\n", + " print(\"✓ Все модули успешно импортированы\")\n", + "except ImportError as e:\n", + " print(f\"✗ Ошибка: {e}\")\n", + " print(\"Решение: выполните 'pip install -e .' из корня проекта\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. 
diff --git a/simple_llm/transformer/callback/resume_training_callback.py b/simple_llm/transformer/callback/resume_training_callback.py
new file mode 100644
index 0000000..cc71bc7
--- /dev/null
+++ b/simple_llm/transformer/callback/resume_training_callback.py
@@ -0,0 +1,60 @@
+import os
+import torch
+from typing import Optional
+from .callback import Callback
+
+class ResumeTrainingCallback(Callback):
+    """Callback that resumes training from the latest checkpoint"""
+
+    def __init__(self, checkpoint_dir: str, resume: bool = True):
+        """
+        Args:
+            checkpoint_dir: Path to the directory with checkpoints
+            resume: Whether to resume training (default=True)
+        """
+        self.checkpoint_dir = checkpoint_dir
+        self.resume = resume
+        self.last_epoch = -1
+
+    def on_train_begin(self, model):
+        if not self.resume:
+            return
+
+        checkpoint_path = self._find_latest_checkpoint()
+        if checkpoint_path:
+            print(f"\n⚡ Resuming training from {checkpoint_path}")
+            # Make sure everything is loaded onto the right device
+            checkpoint = torch.load(checkpoint_path, map_location=model._device)
+
+            model.load_state_dict(checkpoint['model_state_dict'])
+            if 'optimizer_state_dict' in checkpoint:
+                model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+            # Restore the callbacks that were serialized into the checkpoint
+            callback_states = checkpoint.get('callback_states', {})
+            if callback_states and hasattr(model, '_callbacks'):
+                for cb in model._callbacks:
+                    state = callback_states.get(cb.__class__.__name__)
+                    if state is not None and hasattr(cb, 'set_state'):
+                        cb.set_state(state)
+
+            self.last_epoch = checkpoint.get('epoch', -1)
+
+            train_loss = checkpoint.get('train_loss')
+            loss_str = f"{train_loss:.4f}" if train_loss is not None else "N/A"
+            print(f"➔ Continuing from epoch {self.last_epoch + 1}")
+            print(f"➔ Last loss: {loss_str}\n")
+
+    def _find_latest_checkpoint(self) -> Optional[str]:
+        if not os.path.exists(self.checkpoint_dir):
+            return None
+
+        checkpoints = [f for f in os.listdir(self.checkpoint_dir)
+                       if f.startswith('checkpoint_') and f.endswith('.pt')]
+
+        if not checkpoints:
+            return None
+
+        # Sort by modification time and take the newest
+        checkpoints.sort(key=lambda x: os.path.getmtime(os.path.join(self.checkpoint_dir, x)))
+        return os.path.join(self.checkpoint_dir, checkpoints[-1])
\ No newline at end of file
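Putting the pieces together: a sketch of the intended workflow, assuming a `GPT` whose `fit` accepts `callbacks` and a `resume_training` flag as exercised by the tests below. `train_loader` stands in for your `DataLoader`:

```python
from simple_llm.transformer.gpt import GPT
from simple_llm.transformer.callback import (
    LRSchedulerCallback,
    ModelCheckpointCallback,
    ResumeTrainingCallback,
)

model = GPT(vocab_size=100, max_seq_len=10, emb_size=32,
            num_heads=4, head_size=8, num_layers=2)

callbacks = [
    LRSchedulerCallback(lr=0.001),
    ModelCheckpointCallback("data/model", save_best_only=False),
    ResumeTrainingCallback("data/model"),  # loads the latest checkpoint, if any
]

# If the run is interrupted, re-running this call continues from the last checkpoint
model.fit(train_loader, num_epoch=10, callbacks=callbacks, resume_training=True)
```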
diff --git a/simple_llm_demo.ipynb b/simple_llm_demo.ipynb
new file mode 100644
index 0000000..c0336fc
--- /dev/null
+++ b/simple_llm_demo.ipynb
@@ -0,0 +1,326 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# simple_llm demo\n",
+    "## A complete guide to installation and usage"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Installation and setup\n",
+    "\n",
+    "### Clone the repository:\n",
+    "```bash\n",
+    "git clone https://github.com/your_username/simple-llm.git\n",
+    "cd simple-llm\n",
+    "```\n",
+    "\n",
+    "### Install dependencies:\n",
+    "```bash\n",
+    "pip install -e .\n",
+    "pip install torch tqdm\n",
+    "```\n",
+    "\n",
+    "### Create the data directories:\n",
+    "```bash\n",
+    "mkdir -p data/corpus/sample data/model data/tokenizer\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Initializing and checking the environment"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "import os\n",
+    "import torch\n",
+    "\n",
+    "# Check the PyTorch version\n",
+    "print(f\"PyTorch version: {torch.__version__}\")\n",
+    "\n",
+    "# Add the library path\n",
+    "project_path = os.path.abspath('../simple-llm')\n",
+    "sys.path.append(project_path)\n",
+    "print(f\"Project path: {project_path}\")\n",
+    "\n",
+    "# Check that the modules import\n",
+    "try:\n",
+    "    from simple_llm.tokenizer.bpe import BPETokenizer\n",
+    "    from simple_llm.data.get_data import load_text_corpus\n",
+    "    from simple_llm.transformer.gpt import GPT\n",
+    "    print(\"✓ All modules imported successfully\")\n",
+    "except ImportError as e:\n",
+    "    print(f\"✗ Error: {e}\")\n",
+    "    print(\"Fix: run 'pip install -e .' from the project root\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. Working with the tokenizer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Initialize the tokenizer\n",
+    "tokenizer = BPETokenizer()\n",
+    "\n",
+    "# Load and process the text\n",
+    "corpus_path = 'data/corpus/sample/'\n",
+    "if os.path.exists(corpus_path):\n",
+    "    text = load_text_corpus(corpus_path)\n",
+    "    print(f\"Loaded text: {len(text.split())} words\")\n",
+    "\n",
+    "    # Train the tokenizer\n",
+    "    tokenizer.train(text, vocab_size=1000)\n",
+    "    print(f\"Tokenizer trained, vocabulary size: {tokenizer.vocab_size}\")\n",
+    "\n",
+    "    # Tokenization round trip\n",
+    "    test_phrase = \"A quick tokenizer test\"\n",
+    "    tokens = tokenizer.encode(test_phrase)\n",
+    "    print(f\"Text: {test_phrase}\")\n",
+    "    print(f\"Tokens: {tokens}\")\n",
+    "    print(f\"Decoded back: {tokenizer.decode(tokens)}\")\n",
+    "else:\n",
+    "    print(f\"Directory {corpus_path} contains no training data\")\n",
+    "    print(\"Add .txt files to this directory\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. Preparing the training data"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if 'text' in locals():\n",
+    "    # Tokenize the whole corpus\n",
+    "    all_tokens = tokenizer.encode(text)\n",
+    "\n",
+    "    # Build training sequences\n",
+    "    seq_length = 64\n",
+    "    examples = []\n",
+    "    for i in range(0, len(all_tokens) - seq_length - 1, seq_length):\n",
+    "        input_seq = all_tokens[i:i+seq_length]\n",
+    "        target_seq = all_tokens[i+1:i+seq_length+1]\n",
+    "        examples.append((input_seq, target_seq))\n",
+    "\n",
+    "    print(f\"Training examples created: {len(examples)}\")\n",
+    "    print(f\"Sequence length: {seq_length} tokens\")\n",
+    "    print(f\"Sample input: {examples[0][0][:10]}...\")\n",
+    "    print(f\"Sample target: {examples[0][1][:10]}...\")\n",
+    "else:\n",
+    "    print(\"Cannot prepare data: no text loaded\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 5. Training the GPT model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch.optim as optim\n",
+    "\n",
+    "if 'examples' in locals() and len(examples) > 0:\n",
+    "    # Model configuration (keys match the GPT constructor arguments)\n",
+    "    config = {\n",
+    "        'vocab_size': tokenizer.vocab_size,\n",
+    "        'max_seq_len': seq_length,\n",
+    "        'emb_size': 128,\n",
+    "        'num_heads': 4,\n",
+    "        'head_size': 32,  # emb_size / num_heads\n",
+    "        'num_layers': 3\n",
+    "    }\n",
+    "\n",
+    "    # Initialize the model\n",
+    "    model = GPT(**config)\n",
+    "    optimizer = optim.Adam(model.parameters(), lr=0.001)\n",
+    "    criterion = torch.nn.CrossEntropyLoss()\n",
+    "\n",
+    "    # Training loop\n",
+    "    num_epochs = 5\n",
+    "    batch_size = 8\n",
+    "\n",
+    "    for epoch in range(num_epochs):\n",
+    "        total_loss = 0\n",
+    "        num_batches = 0\n",
+    "        model.train()\n",
+    "\n",
+    "        for i in range(0, len(examples), batch_size):\n",
+    "            batch = examples[i:i+batch_size]\n",
+    "            inputs = torch.tensor([ex[0] for ex in batch])\n",
+    "            targets = torch.tensor([ex[1] for ex in batch])\n",
+    "\n",
+    "            optimizer.zero_grad()\n",
+    "            outputs = model(inputs)\n",
+    "            loss = criterion(outputs.view(-1, config['vocab_size']), targets.view(-1))\n",
+    "            loss.backward()\n",
+    "            optimizer.step()\n",
+    "\n",
+    "            total_loss += loss.item()\n",
+    "            num_batches += 1\n",
+    "\n",
+    "        avg_loss = total_loss / num_batches\n",
+    "        print(f\"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}\")\n",
+    "\n",
+    "    print(\"Training complete!\")\n",
+    "else:\n",
+    "    print(\"Cannot start training: no prepared data\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 6. Text generation"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def generate_text(model, tokenizer, prompt, max_len=50, temperature=0.7, context_len=64):\n",
+    "    model.eval()\n",
+    "    tokens = tokenizer.encode(prompt)\n",
+    "\n",
+    "    for _ in range(max_len):\n",
+    "        input_ids = torch.tensor([tokens[-context_len:]])\n",
+    "        with torch.no_grad():\n",
+    "            logits = model(input_ids)[0, -1, :] / temperature\n",
+    "            probs = torch.softmax(logits, dim=-1)\n",
+    "            next_token = torch.multinomial(probs, num_samples=1).item()\n",
+    "        tokens.append(next_token)\n",
+    "\n",
+    "    return tokenizer.decode(tokens)\n",
+    "\n",
+    "if 'model' in locals():\n",
+    "    prompts = [\n",
+    "        \"Today is a beautiful day,\",\n",
+    "        \"Artificial intelligence\",\n",
+    "        \"In the distant future\"\n",
+    "    ]\n",
+    "\n",
+    "    for prompt in prompts:\n",
+    "        generated = generate_text(model, tokenizer, prompt, context_len=config['max_seq_len'])\n",
+    "        print(f\"Prompt: '{prompt}'\")\n",
+    "        print(f\"Output: {generated}\\n\")\n",
+    "else:\n",
+    "    print(\"Model is not trained, generation is not possible\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 7. Saving and loading models"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def save_model(model, tokenizer, model_name):\n",
+    "    model_path = f\"data/model/{model_name}.pth\"\n",
+    "    tokenizer_path = f\"data/tokenizer/{model_name}_tokenizer.json\"\n",
+    "\n",
+    "    torch.save(model.state_dict(), model_path)\n",
+    "    tokenizer.save(tokenizer_path)\n",
+    "    print(f\"Model saved to {model_path}\")\n",
+    "    print(f\"Tokenizer saved to {tokenizer_path}\")\n",
+    "\n",
+    "def load_model(model_name, config):\n",
+    "    model_path = f\"data/model/{model_name}.pth\"\n",
+    "    tokenizer_path = f\"data/tokenizer/{model_name}_tokenizer.json\"\n",
+    "\n",
+    "    model = GPT(**config)\n",
+    "    model.load_state_dict(torch.load(model_path))\n",
+    "\n",
+    "    tokenizer = BPETokenizer()\n",
+    "    tokenizer.load(tokenizer_path)\n",
+    "\n",
+    "    print(f\"Model loaded from {model_path}\")\n",
+    "    return model, tokenizer\n",
+    "\n",
+    "# Usage:\n",
+    "# save_model(model, tokenizer, \"my_first_model\")\n",
+    "# loaded_model, loaded_tokenizer = load_model(\"my_first_model\", config)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 8. Tips for better results\n",
+    "\n",
+    "1. For higher quality:\n",
+    "   - Grow the training corpus\n",
+    "   - Train for more epochs\n",
+    "   - Tune the model hyperparameters\n",
+    "\n",
+    "2. Experiment with:\n",
+    "   - Generation temperature (0.1-1.0)\n",
+    "   - Different prompts\n",
+    "   - Model architecture\n",
+    "\n",
+    "3. Further ideas:\n",
+    "   - Visualizing attention maps\n",
+    "   - Implementing beam search\n",
+    "   - Fine-tuning on domain-specific data"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
\ No newline at end of file
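The tips above suggest experimenting with generation. One common extension, not implemented in the library, is top-k filtering; a sketch of a variant of `generate_text` (the `generate_text_topk` name and defaults are ours):

```python
import torch

def generate_text_topk(model, tokenizer, prompt, max_len=50,
                       temperature=0.7, top_k=20, context_len=64):
    """Samples the next token from the k most likely candidates only."""
    model.eval()
    tokens = tokenizer.encode(prompt)

    for _ in range(max_len):
        input_ids = torch.tensor([tokens[-context_len:]])
        with torch.no_grad():
            logits = model(input_ids)[0, -1, :] / temperature
        # Keep only the top_k logits and renormalize over them
        topk_logits, topk_idx = torch.topk(logits, top_k)
        probs = torch.softmax(topk_logits, dim=-1)
        next_token = topk_idx[torch.multinomial(probs, num_samples=1)].item()
        tokens.append(next_token)

    return tokenizer.decode(tokens)
```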
diff --git a/tests/test_resume_training.py b/tests/test_resume_training.py
new file mode 100644
index 0000000..08b8ff8
--- /dev/null
+++ b/tests/test_resume_training.py
@@ -0,0 +1,197 @@
+import os
+import tempfile
+import torch
+import pytest
+from torch.utils.data import DataLoader, TensorDataset
+
+from simple_llm.transformer.gpt import GPT
+from simple_llm.transformer.callback import (
+    LRSchedulerCallback,
+    ModelCheckpointCallback,
+    ResumeTrainingCallback,
+)
+
+@pytest.fixture
+def sample_data():
+    # Create test data: 100 samples, seq_len=10
+    inputs = torch.randint(0, 100, (100, 10))
+    targets = torch.randint(0, 100, (100, 10))
+    return DataLoader(TensorDataset(inputs, targets), batch_size=10)
+
+@pytest.fixture
+def sample_model():
+    return GPT(
+        vocab_size=100,
+        max_seq_len=10,
+        emb_size=32,
+        num_heads=4,
+        head_size=8,
+        num_layers=2
+    )
+
+def test_model_checkpoint_saving(sample_model, sample_data):
+    """Tests that checkpoints are saved correctly"""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
+        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
+
+        files = os.listdir(tmpdir)
+        assert len(files) == 1
+        assert files[0].startswith('checkpoint_epoch_')
+
+        checkpoint = torch.load(os.path.join(tmpdir, files[0]))
+        assert 'model_state_dict' in checkpoint
+        assert 'optimizer_state_dict' in checkpoint
+        assert 'epoch' in checkpoint
+        assert 'train_loss' in checkpoint
+
+def test_resume_training(sample_model, sample_data):
+    """Tests resuming training from a checkpoint"""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
+        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
+
+        new_model = GPT(
+            vocab_size=100,
+            max_seq_len=10,
+            emb_size=32,
+            num_heads=4,
+            head_size=8,
+            num_layers=2
+        )
+
+        resume_cb = ResumeTrainingCallback(tmpdir)
+        new_model.fit(
+            sample_data,
+            num_epoch=2,
+            callbacks=[resume_cb, checkpoint_cb],
+            resume_training=True
+        )
+
+        assert resume_cb.last_epoch == 0
+        files = os.listdir(tmpdir)
+        assert len(files) == 2
+
+def test_resume_with_missing_checkpoint(sample_model, sample_data):
+    """Tests behavior when no checkpoints exist"""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        assert len(os.listdir(tmpdir)) == 0
+
+        resume_cb = ResumeTrainingCallback(tmpdir)
+        sample_model.fit(
+            sample_data,
+            num_epoch=1,
+            callbacks=[resume_cb],
+            resume_training=True
+        )
+
+        assert resume_cb.last_epoch == -1
+
+def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
+    """Tests handling of corrupted checkpoints"""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        bad_checkpoint = os.path.join(tmpdir, "checkpoint_epoch_0.pt")
+        with open(bad_checkpoint, 'w') as f:
+            f.write("corrupted data")
+
+        resume_cb = ResumeTrainingCallback(tmpdir)
+
+        with pytest.raises(Exception):
+            sample_model.fit(
+                sample_data,
+                num_epoch=1,
+                callbacks=[resume_cb],
+                resume_training=True
+            )
+
+def test_optimizer_state_restoration(sample_model, sample_data):
+    """Tests restoring the optimizer state"""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        checkpoint_cb = ModelCheckpointCallback(tmpdir)
+        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
+
+        original_optimizer_state = sample_model.optimizer.state_dict()
+
+        new_model = GPT(
+            vocab_size=100,
+            max_seq_len=10,
+            emb_size=32,
+            num_heads=4,
+            head_size=8,
+            num_layers=2
+        )
+
+        resume_cb = ResumeTrainingCallback(tmpdir)
+        new_model.fit(
+            sample_data,
+            num_epoch=2,
+            callbacks=[resume_cb, checkpoint_cb],
+            resume_training=True
+        )
+
+        assert 'state' in new_model.optimizer.state_dict()
+        assert 'param_groups' in new_model.optimizer.state_dict()
+
+        # Compare every param_group setting except params and lr
+        # (lr is changed by the scheduler)
+        for key in original_optimizer_state['param_groups'][0]:
+            if key not in ['params', 'lr']:
+                assert (
+                    original_optimizer_state['param_groups'][0][key] ==
+                    new_model.optimizer.state_dict()['param_groups'][0][key]
+                )
+
+def test_multiple_resumes(sample_model, sample_data):
+    """Tests resuming training multiple times"""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        checkpoint_cb = ModelCheckpointCallback(tmpdir)
+        sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
+
+        resume_cb = ResumeTrainingCallback(tmpdir)
+        sample_model.fit(
+            sample_data,
+            num_epoch=2,
+            callbacks=[resume_cb, checkpoint_cb],
+            resume_training=True
+        )
+
+        resume_cb = ResumeTrainingCallback(tmpdir)
+        sample_model.fit(
+            sample_data,
+            num_epoch=3,
+            callbacks=[resume_cb, checkpoint_cb],
+            resume_training=True
+        )
+
+        files = os.listdir(tmpdir)
+        assert len(files) == 3
+
+def test_scheduler_state_restoration(sample_model, sample_data):
+    """Tests restoring the LR scheduler state"""
LR""" + with tempfile.TemporaryDirectory() as tmpdir: + checkpoint_cb = ModelCheckpointCallback(tmpdir) + lr_scheduler_cb = LRSchedulerCallback(lr=0.001) + + sample_model.fit( + sample_data, + num_epoch=1, + callbacks=[checkpoint_cb, lr_scheduler_cb], + learning_rate=0.001 + ) + + # Сохраняем текущий lr с учетом decay + expected_lr = 0.001 * (0.95 ** 1) # decay^epoch + + new_model = GPT( + vocab_size=100, + max_seq_len=10, + emb_size=32, + num_heads=4, + head_size=8, + num_layers=2 + ) + + resume_cb = ResumeTrainingCallback(tmpdir) + new_model.fit( + sample_data, + num_epoch=2, + callbacks=[resume_cb, checkpoint_cb, lr_scheduler_cb], + resume_training=True, + learning_rate=0.001 + ) + + # Проверяем что LR восстановлен с учетом decay + assert new_model.optimizer.param_groups[0]['lr'] == pytest.approx(expected_lr)