2025-07-22 06:24:46 +03:00
|
|
|
|
"""
|
|
|
|
|
|
Пример использования GPT модели из simple_llm
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
2025-07-22 10:53:57 +03:00
|
|
|
|
import os
|
2025-07-22 06:24:46 +03:00
|
|
|
|
from simple_llm.transformer.gpt import GPT
|
|
|
|
|
|
|
2025-07-22 10:53:57 +03:00
|
|
|
|
def use_numeric_generation(config, model):
|
|
|
|
|
|
"""Функция для числовой генерации"""
|
|
|
|
|
|
input_seq = torch.randint(0, config['vocab_size'], (1, 10)).to(config['device'])
|
|
|
|
|
|
print(f"\nЧисловой ввод: {input_seq.tolist()[0]}")
|
|
|
|
|
|
|
|
|
|
|
|
print("\n=== Режимы генерации ===")
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Жадная генерация
|
|
|
|
|
|
greedy_output = model.generate(input_seq.clone(),
|
|
|
|
|
|
max_new_tokens=20,
|
|
|
|
|
|
do_sample=False)
|
|
|
|
|
|
print("\n1. Жадная генерация (детерминированная):")
|
|
|
|
|
|
print(greedy_output.tolist()[0])
|
|
|
|
|
|
|
|
|
|
|
|
# 2. Сэмплирование с температурой
|
|
|
|
|
|
torch.manual_seed(42)
|
|
|
|
|
|
temp_output = model.generate(input_seq.clone(),
|
|
|
|
|
|
max_new_tokens=20,
|
|
|
|
|
|
do_sample=True,
|
|
|
|
|
|
temperature=0.7)
|
|
|
|
|
|
print("\n2. Сэмплирование (температура=0.7):")
|
|
|
|
|
|
print(temp_output.tolist()[0])
|
|
|
|
|
|
|
|
|
|
|
|
# 3. Top-k сэмплирование
|
|
|
|
|
|
torch.manual_seed(42)
|
|
|
|
|
|
topk_output = model.generate(input_seq.clone(),
|
|
|
|
|
|
max_new_tokens=20,
|
|
|
|
|
|
do_sample=True,
|
|
|
|
|
|
top_k=50)
|
|
|
|
|
|
print("\n3. Top-k сэмплирование (k=50):")
|
|
|
|
|
|
print(topk_output.tolist()[0])
|
|
|
|
|
|
|
|
|
|
|
|
# 4. Nucleus (top-p) сэмплирование
|
|
|
|
|
|
try:
|
|
|
|
|
|
torch.manual_seed(42)
|
|
|
|
|
|
topp_output = model.generate(input_seq.clone(),
|
|
|
|
|
|
max_new_tokens=20,
|
|
|
|
|
|
do_sample=True,
|
|
|
|
|
|
top_p=0.9)
|
|
|
|
|
|
print("\n4. Nucleus сэмплирование (p=0.9):")
|
|
|
|
|
|
print(topp_output.tolist()[0])
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
|
print(f"\nОшибка при nucleus сэмплировании: {str(e)}")
|
|
|
|
|
|
print("Пропускаем этот режим генерации")
|
|
|
|
|
|
|
2025-07-22 06:24:46 +03:00
|
|
|
|
def main():
|
|
|
|
|
|
# Конфигурация модели
|
|
|
|
|
|
config = {
|
2025-07-22 10:53:57 +03:00
|
|
|
|
'vocab_size': 10000,
|
|
|
|
|
|
'max_seq_len': 256,
|
|
|
|
|
|
'emb_size': 512,
|
|
|
|
|
|
'num_heads': 8,
|
|
|
|
|
|
'head_size': 64,
|
|
|
|
|
|
'num_layers': 6,
|
|
|
|
|
|
'dropout': 0.1,
|
2025-07-22 06:24:46 +03:00
|
|
|
|
'device': 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
# 1. Инициализация модели
|
|
|
|
|
|
print("Инициализация GPT модели...")
|
|
|
|
|
|
model = GPT(**config)
|
|
|
|
|
|
print(f"Модель создана на устройстве: {config['device']}")
|
|
|
|
|
|
print(f"Количество параметров: {sum(p.numel() for p in model.parameters()):,}")
|
|
|
|
|
|
|
2025-07-22 10:53:57 +03:00
|
|
|
|
# 2. Пример генерации
|
|
|
|
|
|
print("\nИспользуется числовая генерация...")
|
|
|
|
|
|
use_numeric_generation(config, model)
|
2025-07-22 06:24:46 +03:00
|
|
|
|
|
|
|
|
|
|
# 3. Сохранение и загрузка модели
|
|
|
|
|
|
print("\nТест сохранения/загрузки...")
|
|
|
|
|
|
import tempfile
|
|
|
|
|
|
with tempfile.NamedTemporaryFile() as tmp:
|
|
|
|
|
|
model.save(tmp.name)
|
|
|
|
|
|
print(f"Модель сохранена во временный файл: {tmp.name}")
|
|
|
|
|
|
|
|
|
|
|
|
loaded_model = GPT.load(tmp.name, device=config['device'])
|
|
|
|
|
|
print("Модель успешно загружена")
|
|
|
|
|
|
|
2025-07-22 10:53:57 +03:00
|
|
|
|
test_output = loaded_model(torch.randint(0, config['vocab_size'], (1, 5)).to(config['device']))
|
2025-07-22 06:24:46 +03:00
|
|
|
|
print(f"Тест загруженной модели - выходная форма: {test_output.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
2025-07-22 10:53:57 +03:00
|
|
|
|
main()
|