diff --git a/bin/train_gpt_model.py b/bin/train_gpt_model.py index 9e292a5..42578cd 100755 --- a/bin/train_gpt_model.py +++ b/bin/train_gpt_model.py @@ -76,7 +76,8 @@ def main(): model.fit( train_loader=loader, num_epoch=args.epochs, - learning_rate=args.lr + learning_rate=args.lr, + checkpoint_dir=output_dir ) torch.save(model.state_dict(), args.output) diff --git a/simple_llm/transformer/callback/__init__.py b/simple_llm/transformer/callback/__init__.py new file mode 100644 index 0000000..c2fbc19 --- /dev/null +++ b/simple_llm/transformer/callback/__init__.py @@ -0,0 +1,20 @@ +""" +Callback-система для управления обучением GPT. + +Доступные callback-и: +- EarlyStoppingCallback - ранняя остановка +- ModelCheckpointCallback - сохранение чекпоинтов +- LRSchedulerCallback - регулировка learning rate +""" + +from .callback import Callback +from .early_stopping_callback import EarlyStoppingCallback +from .lrs_scheduler_callback import LRSchedulerCallback +from .model_checkpoint_callback import ModelCheckpointCallback + +__all__ = [ + 'Callback', + 'EarlyStoppingCallback', + 'LRSchedulerCallback', + 'ModelCheckpointCallback' +] \ No newline at end of file diff --git a/simple_llm/transformer/callback/callback.py b/simple_llm/transformer/callback/callback.py new file mode 100644 index 0000000..3b506cb --- /dev/null +++ b/simple_llm/transformer/callback/callback.py @@ -0,0 +1,47 @@ +""" +Callback-система для управления процессом обучения GPT. + +Реализует паттерн Observer для мониторинга и управления обучением. +""" + +class Callback: + """Абстрактный базовый класс для всех callback-ов. + + Методы вызываются автоматически во время обучения: + - on_epoch_begin - перед началом эпохи + - on_batch_end - после обработки батча + - on_epoch_end - в конце эпохи + """ + + def on_epoch_begin(self, epoch, model): + """Вызывается перед началом эпохи. + + Args: + epoch (int): Номер текущей эпохи (0-based) + model (GPT): Обучаемая модель GPT + """ + pass + + def on_batch_end(self, batch, model, loss): + """Вызывается после обработки каждого батча. + + Args: + batch (int): Номер батча в текущей эпохе + model (GPT): Обучаемая модель GPT + loss (float): Значение функции потерь на батче + """ + pass + + def on_epoch_end(self, epoch, model, train_loss, val_loss): + """Вызывается в конце эпохи. + + Args: + epoch (int): Номер завершенной эпохи + model (GPT): Обучаемая модель GPT + train_loss (float): Средний loss на обучении + val_loss (float|None): Средний loss на валидации или None + + Returns: + bool: Если True, обучение будет прервано + """ + pass \ No newline at end of file diff --git a/simple_llm/transformer/callback/early_stopping_callback.py b/simple_llm/transformer/callback/early_stopping_callback.py new file mode 100644 index 0000000..c61866c --- /dev/null +++ b/simple_llm/transformer/callback/early_stopping_callback.py @@ -0,0 +1,32 @@ +from .callback import Callback +import torch + +class EarlyStoppingCallback(Callback): + """Останавливает обучение, если loss не улучшается. + + Пример: + >>> early_stopping = EarlyStoppingCallback(patience=3) + >>> model.fit(callbacks=[early_stopping]) + + Args: + patience (int): Количество эпох без улучшения + min_delta (float): Минимальное значимое улучшение loss + """ + def __init__(self, patience=3, min_delta=0.01): + self.patience = patience + self.min_delta = min_delta + self.counter = 0 + self.best_loss = float('inf') + + def on_epoch_end(self, epoch, model, train_loss, val_loss): + current_loss = val_loss if val_loss else train_loss + + if (self.best_loss - current_loss) > self.min_delta: + self.best_loss = current_loss + self.counter = 0 + else: + self.counter += 1 + if self.counter >= self.patience: + print(f"Early stopping triggered at epoch {epoch}") + return True # Остановить обучение + return False \ No newline at end of file diff --git a/simple_llm/transformer/callback/lrs_scheduler_callback.py b/simple_llm/transformer/callback/lrs_scheduler_callback.py new file mode 100644 index 0000000..ac35f8d --- /dev/null +++ b/simple_llm/transformer/callback/lrs_scheduler_callback.py @@ -0,0 +1,21 @@ +from .callback import Callback + +class LRSchedulerCallback(Callback): + """Динамически регулирует learning rate. + + Пример: + >>> lr_scheduler = LRSchedulerCallback(lr=0.001) + >>> model.fit(callbacks=[lr_scheduler]) + + Args: + lr (float): Начальный learning rate + decay (float): Коэффициент уменьшения LR + """ + def __init__(self, lr, decay=0.95): + self.base_lr = lr + self.decay = decay + + def on_epoch_begin(self, epoch, model): + new_lr = self.base_lr * (self.decay ** epoch) + for param_group in model.optimizer.param_groups: + param_group['lr'] = new_lr \ No newline at end of file diff --git a/simple_llm/transformer/callback/model_checkpoint_callback.py b/simple_llm/transformer/callback/model_checkpoint_callback.py new file mode 100644 index 0000000..3f40f8a --- /dev/null +++ b/simple_llm/transformer/callback/model_checkpoint_callback.py @@ -0,0 +1,34 @@ +from .callback import Callback +import torch +import os + +class ModelCheckpointCallback(Callback): + """Сохраняет чекпоинты модели во время обучения. + + Пример: + >>> checkpoint = ModelCheckpointCallback('checkpoints/') + >>> model.fit(callbacks=[checkpoint]) + + Args: + save_dir (str): Директория для сохранения + save_best_only (bool): Сохранять только лучшие модели + """ + def __init__(self, save_dir, save_best_only=True): + self.save_dir = save_dir + self.save_best_only = save_best_only + self.best_loss = float('inf') + + def on_epoch_end(self, epoch, model, train_loss, val_loss): + if not os.path.exists(self.save_dir): + os.makedirs(self.save_dir) + + current_loss = val_loss if val_loss else train_loss + + if not self.save_best_only or current_loss < self.best_loss: + self.best_loss = current_loss + path = os.path.join(self.save_dir, f"checkpoint_epoch_{epoch}.pt") + torch.save({ + 'epoch': epoch, + 'model_state': model.state_dict(), + 'loss': current_loss + }, path) \ No newline at end of file