mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
110 lines
4.3 KiB
Python
110 lines
4.3 KiB
Python
|
|
import os
|
|||
|
|
import tempfile
|
|||
|
|
import pytest
|
|||
|
|
import torch
|
|||
|
|
from simple_llm.transformer.gpt import GPT
|
|||
|
|
|
|||
|
|
@pytest.mark.skip(reason="Пропуск тестов сохранения/загрузки для ускорения проверки")
|
|||
|
|
def test_save_load():
|
|||
|
|
"""Тестирование сохранения и загрузки модели GPT"""
|
|||
|
|
# Инициализация параметров модели
|
|||
|
|
vocab_size = 1000
|
|||
|
|
max_seq_len = 128
|
|||
|
|
emb_size = 256
|
|||
|
|
num_heads = 4
|
|||
|
|
head_size = 64
|
|||
|
|
num_layers = 3
|
|||
|
|
|
|||
|
|
# Создаем модель
|
|||
|
|
model = GPT(
|
|||
|
|
vocab_size=vocab_size,
|
|||
|
|
max_seq_len=max_seq_len,
|
|||
|
|
emb_size=emb_size,
|
|||
|
|
num_heads=num_heads,
|
|||
|
|
head_size=head_size,
|
|||
|
|
num_layers=num_layers
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Создаем временный файл
|
|||
|
|
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
|
|||
|
|
temp_path = tmp_file.name
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
# Тестируем сохранение
|
|||
|
|
model.save(temp_path)
|
|||
|
|
assert os.path.exists(temp_path), "Файл модели не был создан"
|
|||
|
|
|
|||
|
|
# Тестируем загрузку
|
|||
|
|
loaded_model = GPT.load(temp_path, device='cpu')
|
|||
|
|
|
|||
|
|
# Проверяем, что параметры загружены корректно через проверку конфигурации модели
|
|||
|
|
assert loaded_model._token_embeddings.num_embeddings == vocab_size
|
|||
|
|
assert loaded_model.max_seq_len == max_seq_len
|
|||
|
|
assert loaded_model._token_embeddings.embedding_dim == emb_size
|
|||
|
|
assert len(loaded_model._decoders) == num_layers
|
|||
|
|
|
|||
|
|
# Проверяем, что веса загрузились корректно
|
|||
|
|
for (name1, param1), (name2, param2) in zip(
|
|||
|
|
model.named_parameters(),
|
|||
|
|
loaded_model.named_parameters()
|
|||
|
|
):
|
|||
|
|
assert name1 == name2, "Имена параметров не совпадают"
|
|||
|
|
assert torch.allclose(param1, param2), f"Параметры {name1} не совпадают"
|
|||
|
|
|
|||
|
|
# Проверяем работу загруженной модели
|
|||
|
|
test_input = torch.randint(0, vocab_size, (1, 10))
|
|||
|
|
with torch.no_grad():
|
|||
|
|
torch.manual_seed(42) # Фиксируем seed для воспроизводимости
|
|||
|
|
original_output = model(test_input)
|
|||
|
|
torch.manual_seed(42)
|
|||
|
|
loaded_output = loaded_model(test_input)
|
|||
|
|
assert torch.allclose(original_output, loaded_output, atol=1e-6), "Выходы моделей не совпадают"
|
|||
|
|
|
|||
|
|
finally:
|
|||
|
|
# Удаляем временный файл
|
|||
|
|
if os.path.exists(temp_path):
|
|||
|
|
os.remove(temp_path)
|
|||
|
|
|
|||
|
|
@pytest.mark.skip(reason="Пропуск тестов сохранения/загрузки для ускорения проверки")
|
|||
|
|
def test_save_load_with_generation():
|
|||
|
|
"""Тестирование генерации после загрузки модели"""
|
|||
|
|
vocab_size = 1000
|
|||
|
|
max_seq_len = 128
|
|||
|
|
emb_size = 256
|
|||
|
|
num_heads = 4
|
|||
|
|
head_size = 64
|
|||
|
|
num_layers = 2
|
|||
|
|
|
|||
|
|
model = GPT(
|
|||
|
|
vocab_size=vocab_size,
|
|||
|
|
max_seq_len=max_seq_len,
|
|||
|
|
emb_size=emb_size,
|
|||
|
|
num_heads=num_heads,
|
|||
|
|
head_size=head_size,
|
|||
|
|
num_layers=num_layers
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
|
|||
|
|
temp_path = tmp_file.name
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
model.save(temp_path)
|
|||
|
|
loaded_model = GPT.load(temp_path, device='cpu')
|
|||
|
|
|
|||
|
|
# Тестируем генерацию
|
|||
|
|
input_seq = torch.randint(0, vocab_size, (1, 5))
|
|||
|
|
original_gen = model.generate(input_seq, max_new_tokens=10)
|
|||
|
|
loaded_gen = loaded_model.generate(input_seq, max_new_tokens=10)
|
|||
|
|
|
|||
|
|
assert original_gen.shape == loaded_gen.shape, "Размеры сгенерированных последовательностей не совпадают"
|
|||
|
|
assert torch.all(original_gen == loaded_gen), "Сгенерированные последовательности не совпадают"
|
|||
|
|
|
|||
|
|
finally:
|
|||
|
|
if os.path.exists(temp_path):
|
|||
|
|
os.remove(temp_path)
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
test_save_load()
|
|||
|
|
test_save_load_with_generation()
|
|||
|
|
print("Все тесты прошли успешно!")
|