mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
82 lines
3.5 KiB
Python
82 lines
3.5 KiB
Python
|
|
#!/usr/bin/env python3
|
|||
|
|
"""
|
|||
|
|
Генерация текста (финальная версия)
|
|||
|
|
"""
|
|||
|
|
import argparse
|
|||
|
|
import torch
|
|||
|
|
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
|||
|
|
from simple_llm.transformer.gpt import GPT
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
parser = argparse.ArgumentParser()
|
|||
|
|
# Обязательные параметры
|
|||
|
|
parser.add_argument('--model', type=str, required=True,
|
|||
|
|
help='Путь к файлу модели (.pth)')
|
|||
|
|
parser.add_argument('--tokenizer', type=str, required=True,
|
|||
|
|
help='Путь к файлу токенизатора (.json)')
|
|||
|
|
parser.add_argument('--prompt', type=str, required=True,
|
|||
|
|
help='Начальный текст для генерации')
|
|||
|
|
|
|||
|
|
# Параметры модели (должны соответствовать обучению)
|
|||
|
|
parser.add_argument('--seq-len', type=int, default=64,
|
|||
|
|
help='Макс. длина последовательности (как при обучении)')
|
|||
|
|
parser.add_argument('--emb-size', type=int, default=64,
|
|||
|
|
help='Размер эмбеддингов (как при обучении)')
|
|||
|
|
parser.add_argument('--num-heads', type=int, default=4,
|
|||
|
|
help='Количество голов внимания (как при обучении)')
|
|||
|
|
parser.add_argument('--head-size', type=int, default=16,
|
|||
|
|
help='Размер головы внимания (как при обучении)')
|
|||
|
|
parser.add_argument('--num-layers', type=int, default=2,
|
|||
|
|
help='Количество слоёв (как при обучении)')
|
|||
|
|
parser.add_argument('--dropout', type=float, default=0.1,
|
|||
|
|
help='Dropout (как при обучении)')
|
|||
|
|
|
|||
|
|
# Параметры генерации
|
|||
|
|
parser.add_argument('--length', type=int, default=50,
|
|||
|
|
help='Количество генерируемых токенов')
|
|||
|
|
parser.add_argument('--temperature', type=float, default=0.7,
|
|||
|
|
help='Температура сэмплинга (0.1-1.0)')
|
|||
|
|
|
|||
|
|
args = parser.parse_args()
|
|||
|
|
|
|||
|
|
# Загрузка
|
|||
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|||
|
|
print(f"Используется устройство: {device}")
|
|||
|
|
|
|||
|
|
tokenizer = SimpleBPE.load(args.tokenizer)
|
|||
|
|
print(f"Загружен токенизатор (vocab_size={tokenizer.vocab_size})")
|
|||
|
|
|
|||
|
|
# Инициализация модели
|
|||
|
|
model = GPT(
|
|||
|
|
vocab_size=tokenizer.vocab_size,
|
|||
|
|
max_seq_len=args.seq_len,
|
|||
|
|
emb_size=args.emb_size,
|
|||
|
|
num_heads=args.num_heads,
|
|||
|
|
head_size=args.head_size,
|
|||
|
|
num_layers=args.num_layers,
|
|||
|
|
dropout=args.dropout,
|
|||
|
|
device=device
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
model.load_state_dict(torch.load(args.model, map_location=device))
|
|||
|
|
model.eval()
|
|||
|
|
print(f"Загружена модель с {sum(p.numel() for p in model.parameters()):,} параметрами")
|
|||
|
|
|
|||
|
|
# Генерация
|
|||
|
|
print(f"\nГенерация текста для промта: '{args.prompt}'")
|
|||
|
|
tokens = tokenizer.encode(args.prompt)
|
|||
|
|
print(f"Токены промта: {tokens}")
|
|||
|
|
|
|||
|
|
output = model.generate(
|
|||
|
|
x=torch.tensor([tokens], device=device),
|
|||
|
|
max_new_tokens=args.length,
|
|||
|
|
do_sample=True,
|
|||
|
|
temperature=args.temperature
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
print("\n=== Результат ===")
|
|||
|
|
print(tokenizer.decode(output[0].tolist()))
|
|||
|
|
|
|||
|
|
if __name__ == '__main__':
|
|||
|
|
main()
|