Files
simple-llm/example/example_gpt.py

91 lines
3.6 KiB
Python
Raw Permalink Normal View History

"""
Пример использования GPT модели из simple_llm
"""
import torch
import os
from simple_llm.transformer.gpt import GPT
def use_numeric_generation(config, model):
"""Функция для числовой генерации"""
input_seq = torch.randint(0, config['vocab_size'], (1, 10)).to(config['device'])
print(f"\nЧисловой ввод: {input_seq.tolist()[0]}")
print("\n=== Режимы генерации ===")
# 1. Жадная генерация
greedy_output = model.generate(input_seq.clone(),
max_new_tokens=20,
do_sample=False)
print("\n1. Жадная генерация (детерминированная):")
print(greedy_output.tolist()[0])
# 2. Сэмплирование с температурой
torch.manual_seed(42)
temp_output = model.generate(input_seq.clone(),
max_new_tokens=20,
do_sample=True,
temperature=0.7)
print("\n2. Сэмплирование (температура=0.7):")
print(temp_output.tolist()[0])
# 3. Top-k сэмплирование
torch.manual_seed(42)
topk_output = model.generate(input_seq.clone(),
max_new_tokens=20,
do_sample=True,
top_k=50)
print("\n3. Top-k сэмплирование (k=50):")
print(topk_output.tolist()[0])
# 4. Nucleus (top-p) сэмплирование
try:
torch.manual_seed(42)
topp_output = model.generate(input_seq.clone(),
max_new_tokens=20,
do_sample=True,
top_p=0.9)
print("\n4. Nucleus сэмплирование (p=0.9):")
print(topp_output.tolist()[0])
except Exception as e:
print(f"\nОшибка при nucleus сэмплировании: {str(e)}")
print("Пропускаем этот режим генерации")
def main():
# Конфигурация модели
config = {
'vocab_size': 10000,
'max_seq_len': 256,
'emb_size': 512,
'num_heads': 8,
'head_size': 64,
'num_layers': 6,
'dropout': 0.1,
'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}
# 1. Инициализация модели
print("Инициализация GPT модели...")
model = GPT(**config)
print(f"Модель создана на устройстве: {config['device']}")
print(f"Количество параметров: {sum(p.numel() for p in model.parameters()):,}")
# 2. Пример генерации
print("\nИспользуется числовая генерация...")
use_numeric_generation(config, model)
# 3. Сохранение и загрузка модели
print("\nТест сохранения/загрузки...")
import tempfile
with tempfile.NamedTemporaryFile() as tmp:
model.save(tmp.name)
print(f"Модель сохранена во временный файл: {tmp.name}")
loaded_model = GPT.load(tmp.name, device=config['device'])
print("Модель успешно загружена")
test_output = loaded_model(torch.randint(0, config['vocab_size'], (1, 5)).to(config['device']))
print(f"Тест загруженной модели - выходная форма: {test_output.shape}")
if __name__ == "__main__":
main()