mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Fix checkpoint resume logic, global epoch numbering, and robust recovery; update tests for checkpointing; all tests passing
This commit is contained in:
@@ -30,6 +30,7 @@ source venv/bin/activate # Linux/Mac
|
|||||||
# 3. Установите зависимости
|
# 3. Установите зависимости
|
||||||
pip install torch
|
pip install torch
|
||||||
pip install dill tqdm # Основные зависимости для работы
|
pip install dill tqdm # Основные зависимости для работы
|
||||||
|
pip install pytest
|
||||||
|
|
||||||
# Установка simple_llm пакета
|
# Установка simple_llm пакета
|
||||||
pip install .
|
pip install .
|
||||||
@@ -94,8 +95,10 @@ python bin/train_gpt_model.py \
|
|||||||
--epochs 3 \
|
--epochs 3 \
|
||||||
--emb-size 64 \
|
--emb-size 64 \
|
||||||
--num-heads 2 \
|
--num-heads 2 \
|
||||||
--num-layers 2
|
--num-layers 2 \
|
||||||
|
--keep-last-n 3
|
||||||
```
|
```
|
||||||
|
> Аргумент `--keep-last-n` определяет, сколько последних чекпоинтов (snapshot'ов обучения) хранить в папке модели. По умолчанию — 3. Старые файлы удаляются автоматически для экономии места.
|
||||||
|
|
||||||
### ✔️ Восстановление обучения с чекпоинта
|
### ✔️ Восстановление обучения с чекпоинта
|
||||||
|
|
||||||
@@ -130,7 +133,7 @@ python bin/generate_text.py \
|
|||||||
# Последовательно выполните:
|
# Последовательно выполните:
|
||||||
./bin/train_tokenizer.py --corpus data/corpus/sample --output data/tokenizer/bpe.json
|
./bin/train_tokenizer.py --corpus data/corpus/sample --output data/tokenizer/bpe.json
|
||||||
./bin/tokenize_corpus.py --corpus data/corpus/sample --tokenizer data/tokenizer/bpe.json
|
./bin/tokenize_corpus.py --corpus data/corpus/sample --tokenizer data/tokenizer/bpe.json
|
||||||
./bin/train_gpt_model.py --tokens data/tokens/corpus_tokens.pkl --tokenizer data/tokenizer/bpe.json
|
./bin/train_gpt_model.py --tokens data/tokens/corpus_tokens.pkl --tokenizer data/tokenizer/bpe.json --keep-last-n 3
|
||||||
./bin/generate_text.py --model data/model/gpt_model.pth --tokenizer data/tokenizer/bpe.json --prompt "Привет"
|
./bin/generate_text.py --model data/model/gpt_model.pth --tokenizer data/tokenizer/bpe.json --prompt "Привет"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ def main():
|
|||||||
parser.add_argument('--lr', type=float, default=0.0001,
|
parser.add_argument('--lr', type=float, default=0.0001,
|
||||||
help='Learning rate')
|
help='Learning rate')
|
||||||
|
|
||||||
|
parser.add_argument('--keep-last-n', type=int, default=3, help='Сколько последних чекпоинтов хранить')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Проверяем и создаем директорию для сохранения
|
# Проверяем и создаем директорию для сохранения
|
||||||
@@ -88,7 +89,8 @@ def main():
|
|||||||
learning_rate=args.lr,
|
learning_rate=args.lr,
|
||||||
checkpoint_dir=output_dir,
|
checkpoint_dir=output_dir,
|
||||||
resume_training=True,
|
resume_training=True,
|
||||||
start_epoch=start_epoch
|
start_epoch=start_epoch,
|
||||||
|
keep_last_n=args.keep_last_n
|
||||||
)
|
)
|
||||||
torch.save(model.state_dict(), args.output)
|
torch.save(model.state_dict(), args.output)
|
||||||
|
|
||||||
|
|||||||
@@ -15,28 +15,31 @@ class ModelCheckpointCallback(Callback):
|
|||||||
save_best_only (bool): Если True, сохраняет только при улучшении loss
|
save_best_only (bool): Если True, сохраняет только при улучшении loss
|
||||||
save_freq (int): Сохранять каждые N эпох (default=1)
|
save_freq (int): Сохранять каждые N эпох (default=1)
|
||||||
monitor (str): Какой loss мониторить ('val' или 'train')
|
monitor (str): Какой loss мониторить ('val' или 'train')
|
||||||
|
keep_last_n (int): Сколько последних чекпоинтов хранить на диске (по умолчанию 3)
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
save_dir: str,
|
save_dir: str,
|
||||||
save_best_only: bool = True,
|
save_best_only: bool = True,
|
||||||
save_freq: int = 1,
|
save_freq: int = 1,
|
||||||
monitor: str = 'val'):
|
monitor: str = 'val',
|
||||||
|
keep_last_n: int = 3):
|
||||||
self.save_dir = save_dir
|
self.save_dir = save_dir
|
||||||
self.save_best_only = save_best_only
|
self.save_best_only = save_best_only
|
||||||
self.save_freq = save_freq
|
self.save_freq = save_freq
|
||||||
self.monitor = monitor
|
self.monitor = monitor
|
||||||
|
self.keep_last_n = keep_last_n
|
||||||
self.best_loss = float('inf')
|
self.best_loss = float('inf')
|
||||||
|
|
||||||
# Создаем директорию если её нет
|
# Создаем директорию если её нет
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
os.makedirs(save_dir, exist_ok=True)
|
||||||
|
|
||||||
def on_epoch_end(self, epoch, model, train_loss, val_loss):
|
def on_epoch_end(self, global_epoch, model, train_loss, val_loss):
|
||||||
# Решаем какой loss использовать для сравнения
|
# Решаем какой loss использовать для сравнения
|
||||||
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
|
current_loss = val_loss if (self.monitor == 'val' and val_loss is not None) else train_loss
|
||||||
|
|
||||||
# Сохраняем по расписанию или при улучшении
|
# Сохраняем по расписанию или при улучшении
|
||||||
should_save = (
|
should_save = (
|
||||||
(epoch + 1) % self.save_freq == 0 or # по расписанию
|
(global_epoch + 1) % self.save_freq == 0 or # по расписанию
|
||||||
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
|
(self.save_best_only and current_loss < self.best_loss) # или если это лучшая модель
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -44,7 +47,7 @@ class ModelCheckpointCallback(Callback):
|
|||||||
self.best_loss = current_loss
|
self.best_loss = current_loss
|
||||||
checkpoint_path = os.path.join(
|
checkpoint_path = os.path.join(
|
||||||
self.save_dir,
|
self.save_dir,
|
||||||
f"checkpoint_epoch_{epoch}.pt"
|
f"checkpoint_epoch_{global_epoch}.pt"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Собираем состояния всех callback'ов
|
# Собираем состояния всех callback'ов
|
||||||
@@ -55,7 +58,7 @@ class ModelCheckpointCallback(Callback):
|
|||||||
callback_states[cb.__class__.__name__] = cb.get_state()
|
callback_states[cb.__class__.__name__] = cb.get_state()
|
||||||
|
|
||||||
torch.save({
|
torch.save({
|
||||||
'epoch': epoch,
|
'epoch': global_epoch,
|
||||||
'model_state_dict': model.state_dict(),
|
'model_state_dict': model.state_dict(),
|
||||||
'optimizer_state_dict': model.optimizer.state_dict(),
|
'optimizer_state_dict': model.optimizer.state_dict(),
|
||||||
'train_loss': train_loss,
|
'train_loss': train_loss,
|
||||||
@@ -73,3 +76,18 @@ class ModelCheckpointCallback(Callback):
|
|||||||
}, checkpoint_path)
|
}, checkpoint_path)
|
||||||
|
|
||||||
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")
|
print(f"Модель сохранена в {checkpoint_path} (loss: {current_loss:.4f})")
|
||||||
|
self._clean_old_checkpoints()
|
||||||
|
|
||||||
|
def _clean_old_checkpoints(self):
|
||||||
|
import glob
|
||||||
|
files = sorted(
|
||||||
|
glob.glob(os.path.join(self.save_dir, 'checkpoint_epoch_*.pt')),
|
||||||
|
key=os.path.getmtime
|
||||||
|
)
|
||||||
|
if len(files) > self.keep_last_n:
|
||||||
|
for file in files[:-self.keep_last_n]:
|
||||||
|
try:
|
||||||
|
os.remove(file)
|
||||||
|
print(f"Удалён старый чекпоинт: {file}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Ошибка при удалении чекпоинта {file}: {e}")
|
||||||
@@ -20,12 +20,11 @@ class ResumeTrainingCallback(Callback):
|
|||||||
def on_train_begin(self, model):
|
def on_train_begin(self, model):
|
||||||
if not self.resume:
|
if not self.resume:
|
||||||
return
|
return
|
||||||
|
|
||||||
checkpoint_path = self._find_latest_checkpoint()
|
checkpoint_path = self._find_latest_checkpoint()
|
||||||
if checkpoint_path:
|
if checkpoint_path:
|
||||||
|
try:
|
||||||
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
|
print(f"\n⚡ Восстанавливаем обучение из {checkpoint_path}")
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=model._device)
|
checkpoint = torch.load(checkpoint_path, map_location=model._device)
|
||||||
|
|
||||||
# Убедимся, что загружаем на правильное устройство
|
# Убедимся, что загружаем на правильное устройство
|
||||||
model.load_state_dict(checkpoint['model_state_dict'])
|
model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
if 'optimizer_state_dict' in checkpoint:
|
if 'optimizer_state_dict' in checkpoint:
|
||||||
@@ -34,9 +33,23 @@ class ResumeTrainingCallback(Callback):
|
|||||||
if hasattr(model, 'scheduler'):
|
if hasattr(model, 'scheduler'):
|
||||||
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
model.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||||
self.last_epoch = checkpoint.get('epoch', -1)
|
self.last_epoch = checkpoint.get('epoch', -1)
|
||||||
|
|
||||||
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
|
print(f"➔ Продолжаем с эпохи {self.last_epoch + 1}")
|
||||||
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
|
print(f"➔ Последний loss: {checkpoint.get('train_loss', 'N/A'):.4f}\n")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"⚠️ Чекпоинт поврежден или не читается: {checkpoint_path}\n{e}")
|
||||||
|
# Найти максимальный существующий checkpoint по файловой системе
|
||||||
|
import glob, os
|
||||||
|
cp_files = glob.glob(os.path.join(self.checkpoint_dir, 'checkpoint_epoch_*.pt'))
|
||||||
|
if cp_files:
|
||||||
|
try:
|
||||||
|
self.last_epoch = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
|
||||||
|
except Exception:
|
||||||
|
self.last_epoch = -1
|
||||||
|
else:
|
||||||
|
self.last_epoch = -1
|
||||||
|
else:
|
||||||
|
# Если файлов совсем нет
|
||||||
|
self.last_epoch = -1
|
||||||
|
|
||||||
def _find_latest_checkpoint(self) -> Optional[str]:
|
def _find_latest_checkpoint(self) -> Optional[str]:
|
||||||
if not os.path.exists(self.checkpoint_dir):
|
if not os.path.exists(self.checkpoint_dir):
|
||||||
|
|||||||
@@ -297,7 +297,8 @@ class GPT(nn.Module):
|
|||||||
callbacks=None,
|
callbacks=None,
|
||||||
checkpoint_dir=None,
|
checkpoint_dir=None,
|
||||||
resume_training=False,
|
resume_training=False,
|
||||||
start_epoch=0
|
start_epoch=0,
|
||||||
|
keep_last_n: int = 3
|
||||||
):
|
):
|
||||||
"""Обучает модель GPT с поддержкой callback-системы.
|
"""Обучает модель GPT с поддержкой callback-системы.
|
||||||
|
|
||||||
@@ -379,7 +380,7 @@ class GPT(nn.Module):
|
|||||||
])
|
])
|
||||||
|
|
||||||
if checkpoint_dir:
|
if checkpoint_dir:
|
||||||
callbacks.append(ModelCheckpointCallback(checkpoint_dir))
|
callbacks.append(ModelCheckpointCallback(checkpoint_dir, keep_last_n=keep_last_n))
|
||||||
|
|
||||||
# Инициализация оптимизатора ДО вызова callback'ов
|
# Инициализация оптимизатора ДО вызова callback'ов
|
||||||
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||||
@@ -391,19 +392,29 @@ class GPT(nn.Module):
|
|||||||
start_epoch = cb.last_epoch + 1
|
start_epoch = cb.last_epoch + 1
|
||||||
break
|
break
|
||||||
|
|
||||||
# Начало обучения (после инициализации оптимизатора)
|
# Проверка через реальные файлы чекпоинтов (если они удалены или keep_last_n, брать max файл)
|
||||||
for cb in callbacks:
|
import glob, os
|
||||||
cb.on_train_begin(self)
|
if checkpoint_dir is not None:
|
||||||
|
cp_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint_epoch_*.pt'))
|
||||||
|
if cp_files:
|
||||||
|
try:
|
||||||
|
max_cp = max([int(os.path.splitext(os.path.basename(f))[0].split('_')[-1]) for f in cp_files])
|
||||||
|
if max_cp + 1 > start_epoch:
|
||||||
|
start_epoch = max_cp + 1
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch} по {start_epoch + num_epoch - 1})")
|
final_epoch_num = start_epoch + num_epoch
|
||||||
|
print(f"\nНачало обучения GPT на {num_epoch} эпох (с {start_epoch + 1} по {final_epoch_num})")
|
||||||
print(f"Размер батча: {train_loader.batch_size}")
|
print(f"Размер батча: {train_loader.batch_size}")
|
||||||
print(f"Всего батчей: {len(train_loader)}")
|
print(f"Всего батчей: {len(train_loader)}")
|
||||||
print(f"Устройство: {device}\n")
|
print(f"Устройство: {device}\n")
|
||||||
|
|
||||||
for epoch in range(start_epoch, start_epoch + num_epoch):
|
for epoch_local in range(num_epoch):
|
||||||
|
global_epoch = start_epoch + epoch_local
|
||||||
# Вызов callback перед эпохой
|
# Вызов callback перед эпохой
|
||||||
for cb in callbacks:
|
for cb in callbacks:
|
||||||
cb.on_epoch_begin(epoch, self)
|
cb.on_epoch_begin(global_epoch, self)
|
||||||
|
|
||||||
self.train()
|
self.train()
|
||||||
epoch_loss = 0.0
|
epoch_loss = 0.0
|
||||||
@@ -411,7 +422,7 @@ class GPT(nn.Module):
|
|||||||
|
|
||||||
# Прогресс-бар для батчей
|
# Прогресс-бар для батчей
|
||||||
batch_pbar = tqdm(train_loader,
|
batch_pbar = tqdm(train_loader,
|
||||||
desc=f"Эпоха {epoch+1}/{num_epoch}",
|
desc=f"Эпоха {global_epoch+1}/{final_epoch_num}",
|
||||||
leave=False,
|
leave=False,
|
||||||
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
|
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
|
||||||
|
|
||||||
@@ -447,7 +458,7 @@ class GPT(nn.Module):
|
|||||||
self.train_loss = epoch_loss / len(train_loader)
|
self.train_loss = epoch_loss / len(train_loader)
|
||||||
epoch_time = time.time() - start_time
|
epoch_time = time.time() - start_time
|
||||||
|
|
||||||
print(f"\nЭпоха {epoch+1}/{num_epoch} завершена за {epoch_time:.2f} сек")
|
print(f"\nЭпоха {global_epoch+1}/{final_epoch_num} завершена за {epoch_time:.2f} сек")
|
||||||
print(f"Средний Train Loss: {self.train_loss:.4f}")
|
print(f"Средний Train Loss: {self.train_loss:.4f}")
|
||||||
|
|
||||||
if valid_loader is not None:
|
if valid_loader is not None:
|
||||||
@@ -456,7 +467,7 @@ class GPT(nn.Module):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Прогресс-бар для валидации
|
# Прогресс-бар для валидации
|
||||||
valid_pbar = tqdm(valid_loader,
|
valid_pbar = tqdm(valid_loader,
|
||||||
desc=f"Валидация {epoch+1}/{num_epoch}",
|
desc=f"Валидация {global_epoch+1}/{final_epoch_num}",
|
||||||
leave=False)
|
leave=False)
|
||||||
|
|
||||||
for inputs, targets in valid_pbar:
|
for inputs, targets in valid_pbar:
|
||||||
@@ -473,10 +484,16 @@ class GPT(nn.Module):
|
|||||||
self.validation_loss = valid_loss / len(valid_loader)
|
self.validation_loss = valid_loss / len(valid_loader)
|
||||||
print(f"Средний Val Loss: {self.validation_loss:.4f}")
|
print(f"Средний Val Loss: {self.validation_loss:.4f}")
|
||||||
|
|
||||||
# Вызов callback после эпохи
|
# Вызов callback после эпохи — теперь передаём global_epoch (гарантируем уникальность чекпоинта)
|
||||||
stop_training = False
|
stop_training = False
|
||||||
for cb in callbacks:
|
for cb in callbacks:
|
||||||
if cb.on_epoch_end(epoch, self, self.train_loss, self.validation_loss):
|
# ModelCheckpointCallback — строго global_epoch
|
||||||
|
if isinstance(cb, ModelCheckpointCallback):
|
||||||
|
if cb.on_epoch_end(global_epoch, self, self.train_loss, self.validation_loss):
|
||||||
|
stop_training = True
|
||||||
|
else:
|
||||||
|
# Старые/другие коллбэки — epoch_local
|
||||||
|
if cb.on_epoch_end(epoch_local, self, self.train_loss, self.validation_loss):
|
||||||
stop_training = True
|
stop_training = True
|
||||||
|
|
||||||
if stop_training:
|
if stop_training:
|
||||||
|
|||||||
@@ -50,6 +50,10 @@ def test_resume_training(sample_model, sample_data):
|
|||||||
checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
|
checkpoint_cb = ModelCheckpointCallback(tmpdir, save_best_only=False)
|
||||||
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
|
sample_model.fit(sample_data, num_epoch=1, callbacks=[checkpoint_cb])
|
||||||
|
|
||||||
|
# Проверим, что чекпоинт создан
|
||||||
|
files = os.listdir(tmpdir)
|
||||||
|
assert any(f.startswith("checkpoint_epoch_") for f in files)
|
||||||
|
|
||||||
new_model = GPT(
|
new_model = GPT(
|
||||||
vocab_size=100,
|
vocab_size=100,
|
||||||
max_seq_len=10,
|
max_seq_len=10,
|
||||||
@@ -60,16 +64,19 @@ def test_resume_training(sample_model, sample_data):
|
|||||||
)
|
)
|
||||||
|
|
||||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||||
|
resume_cb.on_train_begin(new_model)
|
||||||
|
# После создания чекпоинта и on_train_begin, last_epoch должен быть 0
|
||||||
|
assert resume_cb.last_epoch == 0
|
||||||
|
|
||||||
|
# Дальнейшее обучение с fit/resume
|
||||||
new_model.fit(
|
new_model.fit(
|
||||||
sample_data,
|
sample_data,
|
||||||
num_epoch=2,
|
num_epoch=2,
|
||||||
callbacks=[resume_cb, checkpoint_cb],
|
callbacks=[resume_cb, checkpoint_cb],
|
||||||
resume_training=True
|
resume_training=True
|
||||||
)
|
)
|
||||||
|
files = [f for f in os.listdir(tmpdir) if f.startswith('checkpoint_epoch_')]
|
||||||
assert resume_cb.last_epoch == 0
|
assert len(files) == 3
|
||||||
files = os.listdir(tmpdir)
|
|
||||||
assert len(files) == 2
|
|
||||||
|
|
||||||
def test_resume_with_missing_checkpoint(sample_model, sample_data):
|
def test_resume_with_missing_checkpoint(sample_model, sample_data):
|
||||||
"""Тестирует поведение при отсутствии чекпоинтов"""
|
"""Тестирует поведение при отсутствии чекпоинтов"""
|
||||||
@@ -95,13 +102,13 @@ def test_resume_with_corrupted_checkpoint(sample_model, sample_data):
|
|||||||
|
|
||||||
resume_cb = ResumeTrainingCallback(tmpdir)
|
resume_cb = ResumeTrainingCallback(tmpdir)
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
sample_model.fit(
|
sample_model.fit(
|
||||||
sample_data,
|
sample_data,
|
||||||
num_epoch=1,
|
num_epoch=1,
|
||||||
callbacks=[resume_cb],
|
callbacks=[resume_cb],
|
||||||
resume_training=True
|
resume_training=True
|
||||||
)
|
)
|
||||||
|
assert resume_cb.last_epoch == -1
|
||||||
|
|
||||||
def test_optimizer_state_restoration(sample_model, sample_data):
|
def test_optimizer_state_restoration(sample_model, sample_data):
|
||||||
"""Тестирует восстановление состояния оптимизатора"""
|
"""Тестирует восстановление состояния оптимизатора"""
|
||||||
|
|||||||
Reference in New Issue
Block a user