From 25c067af4a278c7e13f678f11908f7517e65dfca Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Wed, 30 Jul 2025 22:06:26 +0300 Subject: [PATCH] =?UTF-8?q?docs:=20=D0=B0=D0=BA=D1=82=D1=83=D0=B0=D0=BB?= =?UTF-8?q?=D0=B8=D0=B7=D0=B8=D1=80=D0=BE=D0=B2=D0=B0=D0=BD=D0=BE=20=D0=BE?= =?UTF-8?q?=D0=BF=D0=B8=D1=81=D0=B0=D0=BD=D0=B8=D0=B5=20=D0=B0=D0=B2=D1=82?= =?UTF-8?q?=D0=BE=D0=B4=D0=BE=D0=B2=D0=BE=D1=81=D1=81=D1=82=D0=B0=D0=BD?= =?UTF-8?q?=D0=BE=D0=B2=D0=BB=D0=B5=D0=BD=D0=B8=D1=8F=20=D0=BE=D0=B1=D1=83?= =?UTF-8?q?=D1=87=D0=B5=D0=BD=D0=B8=D1=8F=20=D0=B8=20=D0=BE=D0=B1=D0=BD?= =?UTF-8?q?=D0=BE=D0=B2=D0=BB=D0=B5=D0=BD=D0=B0=20=D0=BB=D0=BE=D0=B3=D0=B8?= =?UTF-8?q?=D0=BA=D0=B0=20=D1=8D=D0=BF=D0=BE=D1=85=20=D0=B2=20GPT?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 14 ++++++++++++ bin/train_gpt_model.py | 15 ++++++++++-- doc/train_on_custom_data_ru.md | 15 ++++++++++++ simple_llm/transformer/gpt.py | 42 +++++++++++++++++++++++++--------- 4 files changed, 73 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 8182d3e..572421f 100644 --- a/README.md +++ b/README.md @@ -79,6 +79,20 @@ python bin/train_gpt_model.py \ --num-layers 2 ``` +### ✔️ Восстановление обучения с чекпоинта + +Если обучение было прервано или вы хотите дообучить модель: +- Просто перезапустите команду обучения с теми же параметрами (`bin/train_gpt_model.py ...`). +- Скрипт сам определит последний checkpoint (checkpoint_epoch_X.pt) в папке модели. +- Обучение продолжится с нужной эпохи, параметры останутся неизменными. +- В консоли вы увидите сообщения вроде: + ``` + ⚡ Восстанавливаем обучение с эпохи 12 + Восстановление обучения GPT + ... + Начало обучения GPT на 18 эпох (с 12 по 29) + ``` + ### 4. Генерация текста ```bash python bin/generate_text.py \ diff --git a/bin/train_gpt_model.py b/bin/train_gpt_model.py index 42578cd..a849ed2 100755 --- a/bin/train_gpt_model.py +++ b/bin/train_gpt_model.py @@ -73,11 +73,22 @@ def main(): ) # Обучение + # Определяем стартовую эпоху + start_epoch = 0 + if os.path.exists(output_dir): + checkpoint_files = [f for f in os.listdir(output_dir) if f.startswith('checkpoint_epoch_')] + if checkpoint_files: + last_epoch = max([int(f.split('_')[2].split('.')[0]) for f in checkpoint_files]) + start_epoch = last_epoch + 1 + print(f"⚡ Восстанавливаем обучение с эпохи {start_epoch}") + model.fit( train_loader=loader, - num_epoch=args.epochs, + num_epoch=args.epochs - start_epoch, learning_rate=args.lr, - checkpoint_dir=output_dir + checkpoint_dir=output_dir, + resume_training=True, + start_epoch=start_epoch ) torch.save(model.state_dict(), args.output) diff --git a/doc/train_on_custom_data_ru.md b/doc/train_on_custom_data_ru.md index db08b33..9c857dc 100644 --- a/doc/train_on_custom_data_ru.md +++ b/doc/train_on_custom_data_ru.md @@ -68,6 +68,21 @@ dataset = GetData(data=all_tokens, seq_len=seq_len, device='cuda') ``` ## 5. Обучение модели с помощью fit() + +### Автоматическое восстановление обучения + +- При запуске скрипта обучения в директории модели автоматически ищется последний checkpoint (`checkpoint_epoch_X.pt`) и обучение продолжается с соответствующей эпохи. +- Не нужно вручную указывать эпоху старта — всё происходит автоматически. +- Логи отображают, с какой эпохи продолжается обучение и сколько эпох осталось до конца (см. вывод в консоли). + +**Пример из логов:** +``` +⚡ Восстанавливаем обучение с эпохи 12 +Восстановление обучения GPT +... +Начало обучения GPT на 18 эпох (с 12 по 29) +``` +- Чтобы воспользоваться функцией автодовосстановления, просто повторно запустите команду обучения с теми же CLI-параметрами! ```python from torch.utils.data import DataLoader from simple_llm.transformer.gpt import GPT diff --git a/simple_llm/transformer/gpt.py b/simple_llm/transformer/gpt.py index 469e1b7..df5b30d 100644 --- a/simple_llm/transformer/gpt.py +++ b/simple_llm/transformer/gpt.py @@ -8,6 +8,7 @@ from simple_llm.transformer.decoder import Decoder from simple_llm.transformer.callback import ( EarlyStoppingCallback, LRSchedulerCallback, + ResumeTrainingCallback, ModelCheckpointCallback ) @@ -294,7 +295,9 @@ class GPT(nn.Module): num_epoch: int = 1, learning_rate: float = 0.001, callbacks=None, - checkpoint_dir=None + checkpoint_dir=None, + resume_training=False, + start_epoch=0 ): """Обучает модель GPT с поддержкой callback-системы. @@ -360,27 +363,44 @@ class GPT(nn.Module): device = torch.device(self._device) self.to(device) - # Инициализация callbacks по умолчанию + # Инициализация callback-ов if callbacks is None: callbacks = [] - # Добавляем обязательные callback-и - default_callbacks = [ + # Добавляем Resume callback если нужно + if resume_training and checkpoint_dir: + print(f"Восстановление обучения GPT") + callbacks.insert(0, ResumeTrainingCallback(checkpoint_dir, resume=True)) + + # Стандартные callback-и + callbacks.extend([ EarlyStoppingCallback(patience=5), LRSchedulerCallback(lr=learning_rate) - ] - callbacks = default_callbacks + callbacks - if checkpoint_dir is not None: - callbacks = [ModelCheckpointCallback(checkpoint_dir)] + callbacks - + ]) + + if checkpoint_dir: + callbacks.append(ModelCheckpointCallback(checkpoint_dir)) + + # Инициализация оптимизатора ДО вызова callback'ов self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate) + + # Определяем стартовую эпоху + start_epoch = 0 + for cb in callbacks: + if isinstance(cb, ResumeTrainingCallback) and cb.last_epoch >= 0: + start_epoch = cb.last_epoch + 1 + break + + # Начало обучения (после инициализации оптимизатора) + for cb in callbacks: + cb.on_train_begin(self) - print(f"\nНачало обучения GPT на {num_epoch} эпох") + print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})") print(f"Размер батча: {train_loader.batch_size}") print(f"Всего батчей: {len(train_loader)}") print(f"Устройство: {device}\n") - for epoch in range(num_epoch): + for epoch in range(start_epoch, start_epoch + num_epoch): # Вызов callback перед эпохой for cb in callbacks: cb.on_epoch_begin(epoch, self)