diff --git a/README.md b/README.md index 755d0d7..33e9710 100644 --- a/README.md +++ b/README.md @@ -30,6 +30,7 @@ source venv/bin/activate # Linux/Mac # 3. Установите зависимости pip install torch pip install dill tqdm # Основные зависимости для работы +pip install pytest # Установка simple_llm пакета pip install . @@ -94,8 +95,10 @@ python bin/train_gpt_model.py \ --epochs 3 \ --emb-size 64 \ --num-heads 2 \ - --num-layers 2 + --num-layers 2 \ + --keep-last-n 3 ``` +> Аргумент `--keep-last-n` определяет, сколько последних чекпоинтов (snapshot'ов обучения) хранить в папке модели. По умолчанию — 3. Старые файлы удаляются автоматически для экономии места. ### ✔️ Восстановление обучения с чекпоинта @@ -130,7 +133,7 @@ python bin/generate_text.py \ # Последовательно выполните: ./bin/train_tokenizer.py --corpus data/corpus/sample --output data/tokenizer/bpe.json ./bin/tokenize_corpus.py --corpus data/corpus/sample --tokenizer data/tokenizer/bpe.json -./bin/train_gpt_model.py --tokens data/tokens/corpus_tokens.pkl --tokenizer data/tokenizer/bpe.json +./bin/train_gpt_model.py --tokens data/tokens/corpus_tokens.pkl --tokenizer data/tokenizer/bpe.json --keep-last-n 3 ./bin/generate_text.py --model data/model/gpt_model.pth --tokenizer data/tokenizer/bpe.json --prompt "Привет" ``` diff --git a/bin/train_gpt_model.py b/bin/train_gpt_model.py index a849ed2..4f339ed 100755 --- a/bin/train_gpt_model.py +++ b/bin/train_gpt_model.py @@ -42,6 +42,7 @@ def main(): parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate') + parser.add_argument('--keep-last-n', type=int, default=3, help='Сколько последних чекпоинтов хранить') args = parser.parse_args() # Проверяем и создаем директорию для сохранения @@ -88,7 +89,8 @@ def main(): learning_rate=args.lr, checkpoint_dir=output_dir, resume_training=True, - start_epoch=start_epoch + start_epoch=start_epoch, + keep_last_n=args.keep_last_n ) torch.save(model.state_dict(), args.output) diff --git a/simple_llm/transformer/callback/model_checkpoint_callback.py b/simple_llm/transformer/callback/model_checkpoint_callback.py index 15cdad8..b2edfd7 100644 --- a/simple_llm/transformer/callback/model_checkpoint_callback.py +++ b/simple_llm/transformer/callback/model_checkpoint_callback.py @@ -15,28 +15,31 @@ class ModelCheckpointCallback(Callback): save_best_only (bool): Если True, сохраняет только при улучшении loss save_freq (int): Сохранять каждые N эпох (default=1) monitor (str): Какой loss мониторить ('val' или 'train') + keep_last_n (int): Сколько последних чекпоинтов хранить на диске (по умолчанию 3) """ def __init__(self, save_dir: str, save_best_only: bool = True, save_freq: int = 1, - monitor: str = 'val'): + monitor: str = 'val', + keep_last_n: int = 3): self.save_dir = save_dir self.save_best_only = save_best_only self.save_freq = save_freq self.monitor = monitor + self.keep_last_n = keep_last_n self.best_loss = float('inf') # Создаем директорию если её нет os.makedirs(save_dir, exist_ok=True) - def on_epoch_end(self, epoch, model, train_loss, val_loss): + def on_epoch_end(self, global_epoch, model, train_loss, val_loss): # Решаем какой loss использовать для сравнения current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss # Сохраняем по расписанию или при улучшении should_save = ( - (epoch + 1) % self.save_freq == 0 or # по расписанию + (global_epoch + 1) % self.save_freq == 0 or # по расписанию (self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель ) @@ -44,7 +47,7 @@ class ModelCheckpointCallback(Callback): self.best_loss = current_loss checkpoint_path = os.path.join( self.save_dir, - f"checkpoint_epoch_{epoch}.pt" + f"checkpoint_epoch_{global_epoch}.pt" ) # Собираем состояния всех callback'ов @@ -55,7 +58,7 @@ class ModelCheckpointCallback(Callback): callback_states[cb.__class__.__name__] = cb.get_state() torch.save({ - 'epoch': epoch, + 'epoch': global_epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': model.optimizer.state_dict(), 'train_loss': train_loss, @@ -73,3 +76,18 @@ class ModelCheckpointCallback(Callback): }, checkpoint_path) print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})") + self._clean_old_checkpoints() + + def _clean_old_checkpoints(self): + import glob + files = sorted( + glob.glob(os.path.join(self.save_dir, 'checkpoint_epoch_*.pt')), + key=os.path.getmtime + ) + if len(files) > self.keep_last_n: + for file in files[:-self.keep_last_n]: + try: + os.remove(file) + print(f"Удалён старый чекпоинт: {file}") + except Exception as e: + print(f"Ошибка при удалении чекпоинта {file}: {e}") \ No newline at end of file diff --git a/simple_llm/transformer/callback/resume_training_callback.py b/simple_llm/transformer/callback/resume_training_callback.py index cc71bc7..f91c2be 100644 --- a/simple_llm/transformer/callback/resume_training_callback.py +++ b/simple_llm/transformer/callback/resume_training_callback.py @@ -20,23 +20,36 @@ class ResumeTrainingCallback(Callback): def on_train_begin(self, model): if not self.resume: return - checkpoint_path = self._find_latest_checkpoint() if checkpoint_path: - print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location=model._device) - - # Убедимся, что загружаем на правильное устройство - model.load_state_dict(checkpoint['model_state_dict']) - if 'optimizer_state_dict' in checkpoint: - model.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) - if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None: - if hasattr(model, 'scheduler'): - model.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) - self.last_epoch = checkpoint.get('epoch', -1) - - print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}") - print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n") + try: + print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=model._device) + # Убедимся, что загружаем на правильное устройство + model.load_state_dict(checkpoint['model_state_dict']) + if 'optimizer_state_dict' in checkpoint: + model.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None: + if hasattr(model, 'scheduler'): + model.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) + self.last_epoch = checkpoint.get('epoch', -1) + print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}") + print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n") + except Exception as e: + print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}") + # Найти максимальный существующий checkpoint по файловой системе + import glob, os + cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt')) + if cp_files: + try: + self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files]) + except Exception: + self.last_epoch = -1 + else: + self.last_epoch = -1 + else: + # Если файлов совсем нет + self.last_epoch = -1 def _find_latest_checkpoint(self) -> Optional[str]: if not os.path.exists(self.checkpoint_dir): diff --git a/simple_llm/transformer/gpt.py b/simple_llm/transformer/gpt.py index df5b30d..18b0d3c 100644 --- a/simple_llm/transformer/gpt.py +++ b/simple_llm/transformer/gpt.py @@ -297,7 +297,8 @@ class GPT(nn.Module): callbacks=None, checkpoint_dir=None, resume_training=False, - start_epoch=0 + start_epoch=0, + keep_last_n: int = 3 ): """Обучает модель GPT с поддержкой callback-системы. @@ -379,7 +380,7 @@ class GPT(nn.Module): ]) if checkpoint_dir: - callbacks.append(ModelCheckpointCallback(checkpoint_dir)) + callbacks.append(ModelCheckpointCallback(checkpoint_dir, keep_last_n=keep_last_n)) # Инициализация оптимизатора ДО вызова callback'ов self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate) @@ -391,19 +392,29 @@ class GPT(nn.Module): start_epoch = cb.last_epoch + 1 break - # Начало обучения (после инициализации оптимизатора) - for cb in callbacks: - cb.on_train_begin(self) - - print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})") + # Проверка через реальные файлы чекпоинтов (если они удалены или keep_last_n, брать max файл) + import glob, os + if checkpoint_dir is not None: + cp_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt')) + if cp_files: + try: + max_cp = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files]) + if max_cp + 1 > start_epoch: + start_epoch = max_cp + 1 + except Exception: + pass + + final_epoch_num = start_epoch + num_epoch + print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch + 1} по {final_epoch_num})") print(f"Размер батча: {train_loader.batch_size}") print(f"Всего батчей: {len(train_loader)}") print(f"Устройство: {device}\n") - for epoch in range(start_epoch, start_epoch + num_epoch): + for epoch_local in range(num_epoch): + global_epoch = start_epoch + epoch_local # Вызов callback перед эпохой for cb in callbacks: - cb.on_epoch_begin(epoch, self) + cb.on_epoch_begin(global_epoch, self) self.train() epoch_loss = 0.0 @@ -411,7 +422,7 @@ class GPT(nn.Module): # Прогресс-бар для батчей batch_pbar = tqdm(train_loader, - desc=f"Эпоха {epoch+1}/{num_epoch}", + desc=f"Эпоха {global_epoch+1}/{final_epoch_num}", leave=False, bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]') @@ -447,7 +458,7 @@ class GPT(nn.Module): self.train_loss = epoch_loss / len(train_loader) epoch_time = time.time() - start_time - print(f"\nЭпоха {epoch+1}/{num_epoch} завершена за {epoch_time:.2f} сек") + print(f"\nЭпоха {global_epoch+1}/{final_epoch_num} завершена за {epoch_time:.2f} сек") print(f"Средний Train Loss: {self.train_loss:.4f}") if valid_loader is not None: @@ -456,7 +467,7 @@ class GPT(nn.Module): with torch.no_grad(): # Прогресс-бар для валидации valid_pbar = tqdm(valid_loader, - desc=f"Валидация {epoch+1}/{num_epoch}", + desc=f"Валидация {global_epoch+1}/{final_epoch_num}", leave=False) for inputs, targets in valid_pbar: @@ -473,11 +484,17 @@ class GPT(nn.Module): self.validation_loss = valid_loss / len(valid_loader) print(f"Средний Val Loss: {self.validation_loss:.4f}") - # Вызов callback после эпохи + # Вызов callback после эпохи — теперь передаём global_epoch (гарантируем уникальность чекпоинта) stop_training = False for cb in callbacks: - if cb.on_epoch_end(epoch, self, self.train_loss, self.validation_loss): - stop_training = True + # ModelCheckpointCallback — строго global_epoch + if isinstance(cb, ModelCheckpointCallback): + if cb.on_epoch_end(global_epoch, self, self.train_loss, self.validation_loss): + stop_training = True + else: + # Старые/другие коллбэки — epoch_local + if cb.on_epoch_end(epoch_local, self, self.train_loss, self.validation_loss): + stop_training = True if stop_training: print("Обучение остановлено callback-ом") diff --git a/tests/test_resume_training.py b/tests/test_resume_training.py index 08b8ff8..f157a61 100644 --- a/tests/test_resume_training.py +++ b/tests/test_resume_training.py @@ -49,7 +49,11 @@ def test_resume_training(sample_model, sample_data): with tempfile.TemporaryDirectory() as tmpdir: checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False) sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb]) - + + # Проверим, что чекпоинт создан + files = os.listdir(tmpdir) + assert any(f.startswith("checkpoint_epoch_") for f in files) + new_model = GPT( vocab_size=100, max_seq_len=10, @@ -58,18 +62,21 @@ def test_resume_training(sample_model, sample_data): head_size=8, num_layers=2 ) - + resume_cb = ResumeTrainingCallback(tmpdir) + resume_cb.on_train_begin(new_model) + # После создания чекпоинта и on_train_begin, last_epoch должен быть 0 + assert resume_cb.last_epoch == 0 + + # Дальнейшее обучение с fit/resume new_model.fit( - sample_data, - num_epoch=2, + sample_data, + num_epoch=2, callbacks=[resume_cb, checkpoint_cb], resume_training=True ) - - assert resume_cb.last_epoch == 0 - files = os.listdir(tmpdir) - assert len(files) == 2 + files = [f for f in os.listdir(tmpdir) if f.startswith('checkpoint_epoch_')] + assert len(files) == 3 def test_resume_with_missing_checkpoint(sample_model, sample_data): """Тестирует поведение при отсутствии чекпоинтов""" @@ -95,13 +102,13 @@ def test_resume_with_corrupted_checkpoint(sample_model, sample_data): resume_cb = ResumeTrainingCallback(tmpdir) - with pytest.raises(Exception): - sample_model.fit( - sample_data, - num_epoch=1, - callbacks=[resume_cb], - resume_training=True - ) + sample_model.fit( + sample_data, + num_epoch=1, + callbacks=[resume_cb], + resume_training=True + ) + assert resume_cb.last_epoch == -1 def test_optimizer_state_restoration(sample_model, sample_data): """Тестирует восстановление состояния оптимизатора"""