mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 13:03:55 +00:00
Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing
This commit is contained in:
@@ -30,6 +30,7 @@ source venv/bin/activate # Linux/Mac
|
||||
# 3. Установите зависимости
|
||||
pip install torch
|
||||
pip install dill tqdm # Основные зависимости для работы
|
||||
pip install pytest
|
||||
|
||||
# Установка simple_llm пакета
|
||||
pip install .
|
||||
@@ -94,8 +95,10 @@ python bin/train_gpt_model.py \
|
||||
--epochs 3 \
|
||||
--emb-size 64 \
|
||||
--num-heads 2 \
|
||||
--num-layers 2
|
||||
--num-layers 2 \
|
||||
--keep-last-n 3
|
||||
```
|
||||
> Аргумент `--keep-last-n` определяет, сколько последних чекпоинтов (snapshot'ов обучения) хранить в папке модели. По умолчанию — 3. Старые файлы удаляются автоматически для экономии места.
|
||||
|
||||
### ✔️ Восстановление обучения с чекпоинта
|
||||
|
||||
@@ -130,7 +133,7 @@ python bin/generate_text.py \
|
||||
# Последовательно выполните:
|
||||
./bin/train_tokenizer.py --corpus data/corpus/sample --output data/tokenizer/bpe.json
|
||||
./bin/tokenize_corpus.py --corpus data/corpus/sample --tokenizer data/tokenizer/bpe.json
|
||||
./bin/train_gpt_model.py --tokens data/tokens/corpus_tokens.pkl --tokenizer data/tokenizer/bpe.json
|
||||
./bin/train_gpt_model.py --tokens data/tokens/corpus_tokens.pkl --tokenizer data/tokenizer/bpe.json --keep-last-n 3
|
||||
./bin/generate_text.py --model data/model/gpt_model.pth --tokenizer data/tokenizer/bpe.json --prompt "Привет"
|
||||
```
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ def main():
|
||||
parser.add_argument('--lr', type=float, default=0.0001,
|
||||
help='Learning rate')
|
||||
|
||||
parser.add_argument('--keep-last-n', type=int, default=3, help='Сколько последних чекпоинтов хранить')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Проверяем и создаем директорию для сохранения
|
||||
@@ -88,7 +89,8 @@ def main():
|
||||
learning_rate=args.lr,
|
||||
checkpoint_dir=output_dir,
|
||||
resume_training=True,
|
||||
start_epoch=start_epoch
|
||||
start_epoch=start_epoch,
|
||||
keep_last_n=args.keep_last_n
|
||||
)
|
||||
torch.save(model.state_dict(), args.output)
|
||||
|
||||
|
||||
@@ -15,28 +15,31 @@ class ModelCheckpointCallback(Callback):
|
||||
save_best_only (bool): Если True, сохраняет только при улучшении loss
|
||||
save_freq (int): Сохранять каждые N эпох (default=1)
|
||||
monitor (str): Какой loss мониторить ('val' или 'train')
|
||||
keep_last_n (int): Сколько последних чекпоинтов хранить на диске (по умолчанию 3)
|
||||
"""
|
||||
def __init__(self,
|
||||
save_dir: str,
|
||||
save_best_only: bool = True,
|
||||
save_freq: int = 1,
|
||||
monitor: str = 'val'):
|
||||
monitor: str = 'val',
|
||||
keep_last_n: int = 3):
|
||||
self.save_dir = save_dir
|
||||
self.save_best_only = save_best_only
|
||||
self.save_freq = save_freq
|
||||
self.monitor = monitor
|
||||
self.keep_last_n = keep_last_n
|
||||
self.best_loss = float('inf')
|
||||
|
||||
# Создаем директорию если её нет
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
def on_epoch_end(self, epoch, model, train_loss, val_loss):
|
||||
def on_epoch_end(self, global_epoch, model, train_loss, val_loss):
|
||||
# Решаем какой loss использовать для сравнения
|
||||
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
|
||||
|
||||
# Сохраняем по расписанию или при улучшении
|
||||
should_save = (
|
||||
(epoch + 1) % self.save_freq == 0 or # по расписанию
|
||||
(global_epoch + 1) % self.save_freq == 0 or # по расписанию
|
||||
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
|
||||
)
|
||||
|
||||
@@ -44,7 +47,7 @@ class ModelCheckpointCallback(Callback):
|
||||
self.best_loss = current_loss
|
||||
checkpoint_path = os.path.join(
|
||||
self.save_dir,
|
||||
f"checkpoint_epoch_{epoch}.pt"
|
||||
f"checkpoint_epoch_{global_epoch}.pt"
|
||||
)
|
||||
|
||||
# Собираем состояния всех callback'ов
|
||||
@@ -55,7 +58,7 @@ class ModelCheckpointCallback(Callback):
|
||||
callback_states[cb.__class__.__name__] = cb.get_state()
|
||||
|
||||
torch.save({
|
||||
'epoch': epoch,
|
||||
'epoch': global_epoch,
|
||||
'model_state_dict': model.state_dict(),
|
||||
'optimizer_state_dict': model.optimizer.state_dict(),
|
||||
'train_loss': train_loss,
|
||||
@@ -73,3 +76,18 @@ class ModelCheckpointCallback(Callback):
|
||||
}, checkpoint_path)
|
||||
|
||||
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")
|
||||
self._clean_old_checkpoints()
|
||||
|
||||
def _clean_old_checkpoints(self):
|
||||
import glob
|
||||
files = sorted(
|
||||
glob.glob(os.path.join(self.save_dir, 'checkpoint_epoch_*.pt')),
|
||||
key=os.path.getmtime
|
||||
)
|
||||
if len(files) > self.keep_last_n:
|
||||
for file in files[:-self.keep_last_n]:
|
||||
try:
|
||||
os.remove(file)
|
||||
print(f"Удалён старый чекпоинт: {file}")
|
||||
except Exception as e:
|
||||
print(f"Ошибка при удалении чекпоинта {file}: {e}")
|
||||
@@ -20,23 +20,36 @@ class ResumeTrainingCallback(Callback):
|
||||
def on_train_begin(self, model):
|
||||
if not self.resume:
|
||||
return
|
||||
|
||||
checkpoint_path = self._find_latest_checkpoint()
|
||||
if checkpoint_path:
|
||||
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location=model._device)
|
||||
|
||||
# Убедимся, что загружаем на правильное устройство
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
if 'optimizer_state_dict' in checkpoint:
|
||||
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
|
||||
if hasattr(model, 'scheduler'):
|
||||
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
self.last_epoch = checkpoint.get('epoch', -1)
|
||||
|
||||
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
|
||||
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
|
||||
try:
|
||||
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
|
||||
checkpoint = torch.load(checkpoint_path, map_location=model._device)
|
||||
# Убедимся, что загружаем на правильное устройство
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
if 'optimizer_state_dict' in checkpoint:
|
||||
model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict'] is not None:
|
||||
if hasattr(model, 'scheduler'):
|
||||
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
self.last_epoch = checkpoint.get('epoch', -1)
|
||||
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
|
||||
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
|
||||
except Exception as e:
|
||||
print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
|
||||
# Найти максимальный существующий checkpoint по файловой системе
|
||||
import glob, os
|
||||
cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
|
||||
if cp_files:
|
||||
try:
|
||||
self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
|
||||
except Exception:
|
||||
self.last_epoch = -1
|
||||
else:
|
||||
self.last_epoch = -1
|
||||
else:
|
||||
# Если файлов совсем нет
|
||||
self.last_epoch = -1
|
||||
|
||||
def _find_latest_checkpoint(self) -> Optional[str]:
|
||||
if not os.path.exists(self.checkpoint_dir):
|
||||
|
||||
@@ -297,7 +297,8 @@ class GPT(nn.Module):
|
||||
callbacks=None,
|
||||
checkpoint_dir=None,
|
||||
resume_training=False,
|
||||
start_epoch=0
|
||||
start_epoch=0,
|
||||
keep_last_n: int = 3
|
||||
):
|
||||
"""Обучает модель GPT с поддержкой callback-системы.
|
||||
|
||||
@@ -379,7 +380,7 @@ class GPT(nn.Module):
|
||||
])
|
||||
|
||||
if checkpoint_dir:
|
||||
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
|
||||
callbacks.append(ModelCheckpointCallback(checkpoint_dir, keep_last_n=keep_last_n))
|
||||
|
||||
# Инициализация оптимизатора ДО вызова callback'ов
|
||||
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||
@@ -391,19 +392,29 @@ class GPT(nn.Module):
|
||||
start_epoch = cb.last_epoch + 1
|
||||
break
|
||||
|
||||
# Начало обучения (после инициализации оптимизатора)
|
||||
for cb in callbacks:
|
||||
cb.on_train_begin(self)
|
||||
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
|
||||
# Проверка через реальные файлы чекпоинтов (если они удалены или keep_last_n, брать max файл)
|
||||
import glob, os
|
||||
if checkpoint_dir is not None:
|
||||
cp_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))
|
||||
if cp_files:
|
||||
try:
|
||||
max_cp = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
|
||||
if max_cp + 1 > start_epoch:
|
||||
start_epoch = max_cp + 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
final_epoch_num = start_epoch + num_epoch
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch + 1} по {final_epoch_num})")
|
||||
print(f"Размер батча: {train_loader.batch_size}")
|
||||
print(f"Всего батчей: {len(train_loader)}")
|
||||
print(f"Устройство: {device}\n")
|
||||
|
||||
for epoch in range(start_epoch, start_epoch + num_epoch):
|
||||
for epoch_local in range(num_epoch):
|
||||
global_epoch = start_epoch + epoch_local
|
||||
# Вызов callback перед эпохой
|
||||
for cb in callbacks:
|
||||
cb.on_epoch_begin(epoch, self)
|
||||
cb.on_epoch_begin(global_epoch, self)
|
||||
|
||||
self.train()
|
||||
epoch_loss = 0.0
|
||||
@@ -411,7 +422,7 @@ class GPT(nn.Module):
|
||||
|
||||
# Прогресс-бар для батчей
|
||||
batch_pbar = tqdm(train_loader,
|
||||
desc=f"Эпоха {epoch+1}/{num_epoch}",
|
||||
desc=f"Эпоха {global_epoch+1}/{final_epoch_num}",
|
||||
leave=False,
|
||||
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
|
||||
|
||||
@@ -447,7 +458,7 @@ class GPT(nn.Module):
|
||||
self.train_loss = epoch_loss / len(train_loader)
|
||||
epoch_time = time.time() - start_time
|
||||
|
||||
print(f"\nЭпоха {epoch+1}/{num_epoch} завершена за {epoch_time:.2f} сек")
|
||||
print(f"\nЭпоха {global_epoch+1}/{final_epoch_num} завершена за {epoch_time:.2f} сек")
|
||||
print(f"Средний Train Loss: {self.train_loss:.4f}")
|
||||
|
||||
if valid_loader is not None:
|
||||
@@ -456,7 +467,7 @@ class GPT(nn.Module):
|
||||
with torch.no_grad():
|
||||
# Прогресс-бар для валидации
|
||||
valid_pbar = tqdm(valid_loader,
|
||||
desc=f"Валидация {epoch+1}/{num_epoch}",
|
||||
desc=f"Валидация {global_epoch+1}/{final_epoch_num}",
|
||||
leave=False)
|
||||
|
||||
for inputs, targets in valid_pbar:
|
||||
@@ -473,11 +484,17 @@ class GPT(nn.Module):
|
||||
self.validation_loss = valid_loss / len(valid_loader)
|
||||
print(f"Средний Val Loss: {self.validation_loss:.4f}")
|
||||
|
||||
# Вызов callback после эпохи
|
||||
# Вызов callback после эпохи — теперь передаём global_epoch (гарантируем уникальность чекпоинта)
|
||||
stop_training = False
|
||||
for cb in callbacks:
|
||||
if cb.on_epoch_end(epoch, self, self.train_loss, self.validation_loss):
|
||||
stop_training = True
|
||||
# ModelCheckpointCallback — строго global_epoch
|
||||
if isinstance(cb, ModelCheckpointCallback):
|
||||
if cb.on_epoch_end(global_epoch, self, self.train_loss, self.validation_loss):
|
||||
stop_training = True
|
||||
else:
|
||||
# Старые/другие коллбэки — epoch_local
|
||||
if cb.on_epoch_end(epoch_local, self, self.train_loss, self.validation_loss):
|
||||
stop_training = True
|
||||
|
||||
if stop_training:
|
||||
print("Обучение остановлено callback-ом")
|
||||
|
||||
@@ -49,7 +49,11 @@ def test_resume_training(sample_model, sample_data):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
|
||||
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
|
||||
|
||||
|
||||
# Проверим, что чекпоинт создан
|
||||
files = os.listdir(tmpdir)
|
||||
assert any(f.startswith("checkpoint_epoch_") for f in files)
|
||||
|
||||
new_model = GPT(
|
||||
vocab_size=100,
|
||||
max_seq_len=10,
|
||||
@@ -58,18 +62,21 @@ def test_resume_training(sample_model, sample_data):
|
||||
head_size=8,
|
||||
num_layers=2
|
||||
)
|
||||
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
resume_cb.on_train_begin(new_model)
|
||||
# После создания чекпоинта и on_train_begin, last_epoch должен быть 0
|
||||
assert resume_cb.last_epoch == 0
|
||||
|
||||
# Дальнейшее обучение с fit/resume
|
||||
new_model.fit(
|
||||
sample_data,
|
||||
num_epoch=2,
|
||||
sample_data,
|
||||
num_epoch=2,
|
||||
callbacks=[resume_cb, checkpoint_cb],
|
||||
resume_training=True
|
||||
)
|
||||
|
||||
assert resume_cb.last_epoch == 0
|
||||
files = os.listdir(tmpdir)
|
||||
assert len(files) == 2
|
||||
files = [f for f in os.listdir(tmpdir) if f.startswith('checkpoint_epoch_')]
|
||||
assert len(files) == 3
|
||||
|
||||
def test_resume_with_missing_checkpoint(sample_model, sample_data):
|
||||
"""Тестирует поведение при отсутствии чекпоинтов"""
|
||||
@@ -95,13 +102,13 @@ def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
|
||||
|
||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=1,
|
||||
callbacks=[resume_cb],
|
||||
resume_training=True
|
||||
)
|
||||
sample_model.fit(
|
||||
sample_data,
|
||||
num_epoch=1,
|
||||
callbacks=[resume_cb],
|
||||
resume_training=True
|
||||
)
|
||||
assert resume_cb.last_epoch == -1
|
||||
|
||||
def test_optimizer_state_restoration(sample_model, sample_data):
|
||||
"""Тестирует восстановление состояния оптимизатора"""
|
||||
|
||||
Reference in New Issue
Block a user