Files
simple-llm/example/generate_text.py

51 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Генерация текста с помощью обученной GPT-модели и токенизатора
"""
import torch
from simple_llm.transformer.gpt import GPT
from simple_llm.tokenizer.bpe import BPE
if __name__ == "__main__":
import torch
# Определяем устройство
#if torch.cuda.is_available():
# device = 'cuda'
#elif getattr(torch.backends, 'mps', None) and torch.backends.mps.is_available():
# device = 'mps' # Apple Silicon
#else:
# device = 'cpu'
device = 'cpu'
print(f"Используется устройство: {device}")
# Загрузим токенизатор и модель
tokenizer = BPE.load('data/tokenizer/bpe_tokenizer.json')
model = GPT(
vocab_size=tokenizer.vocab_size,
max_seq_len=64,
emb_size=256,
num_heads=4,
head_size=64,
num_layers=4,
device=device
)
model.load_state_dict(torch.load('data/model/simple_llm_gpt.pth', map_location=device))
model.eval()
# Введите начальный текст
prompt = "Привет, мир! "
prompt_tokens = tokenizer.encode(prompt)
print(f"Токены prompt: {prompt_tokens}")
print(f"Размер словаря токенизатора: {tokenizer.vocab_size}")
if any(idx >= tokenizer.vocab_size or idx < 0 for idx in prompt_tokens):
print("ВНИМАНИЕ: В prompt есть токены с индексом вне диапазона словаря! Генерация невозможна.")
exit(1)
input_ids = torch.tensor([prompt_tokens], device=device)
output = model.generate(
x=input_ids,
max_new_tokens=30,
do_sample=True,
temperature=1.0
)
result = tokenizer.decode(output[0].tolist())
print("Сгенерированный текст:", result)