mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
72 lines
3.2 KiB
Python
72 lines
3.2 KiB
Python
|
|
"""
|
|||
|
|
Пример использования GPT модели из simple_llm
|
|||
|
|
|
|||
|
|
1. Инициализация модели
|
|||
|
|
2. Генерация текста
|
|||
|
|
3. Сохранение/загрузка модели
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import torch
|
|||
|
|
from simple_llm.transformer.gpt import GPT
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
# Конфигурация модели
|
|||
|
|
config = {
|
|||
|
|
'vocab_size': 10000, # Размер словаря
|
|||
|
|
'max_seq_len': 256, # Макс. длина последовательности
|
|||
|
|
'emb_size': 512, # Размерность эмбеддингов
|
|||
|
|
'num_heads': 8, # Количество голов внимания
|
|||
|
|
'head_size': 64, # Размер каждой головы внимания
|
|||
|
|
'num_layers': 6, # Количество слоев декодера
|
|||
|
|
'dropout': 0.1, # Dropout
|
|||
|
|
'device': 'cuda' if torch.cuda.is_available() else 'cpu'
|
|||
|
|
}
|
|||
|
|
|
|||
|
|
# 1. Инициализация модели
|
|||
|
|
print("Инициализация GPT модели...")
|
|||
|
|
model = GPT(**config)
|
|||
|
|
print(f"Модель создана на устройстве: {config['device']}")
|
|||
|
|
print(f"Количество параметров: {sum(p.numel() for p in model.parameters()):,}")
|
|||
|
|
|
|||
|
|
# 2. Пример генерации с токенизатором
|
|||
|
|
try:
|
|||
|
|
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
|||
|
|
print("\nИнициализация токенизатора...")
|
|||
|
|
tokenizer = SimpleBPE()
|
|||
|
|
|
|||
|
|
text = "Пример текста для генерации"
|
|||
|
|
print(f"Исходный текст: '{text}'")
|
|||
|
|
|
|||
|
|
input_ids = tokenizer.encode(text)
|
|||
|
|
print(f"Токенизированный ввод: {input_ids}")
|
|||
|
|
|
|||
|
|
input_seq = torch.tensor([input_ids], device=config['device'])
|
|||
|
|
generated = model.generate(input_seq, max_new_tokens=20)
|
|||
|
|
|
|||
|
|
decoded_text = tokenizer.decode(generated[0].tolist())
|
|||
|
|
print(f"\nСгенерированный текст: '{decoded_text}'")
|
|||
|
|
except ImportError:
|
|||
|
|
print("\nТокенизатор не найден, используется числовая генерация...")
|
|||
|
|
input_seq = torch.randint(0, config['vocab_size'], (1, 10)).to(config['device'])
|
|||
|
|
print(f"Числовой ввод: {input_seq.tolist()[0]}")
|
|||
|
|
|
|||
|
|
generated = model.generate(input_seq, max_new_tokens=20)
|
|||
|
|
print(f"Числовой вывод: {generated.tolist()[0]}")
|
|||
|
|
|
|||
|
|
# 3. Сохранение и загрузка модели
|
|||
|
|
print("\nТест сохранения/загрузки...")
|
|||
|
|
import tempfile
|
|||
|
|
with tempfile.NamedTemporaryFile() as tmp:
|
|||
|
|
model.save(tmp.name)
|
|||
|
|
print(f"Модель сохранена во временный файл: {tmp.name}")
|
|||
|
|
|
|||
|
|
loaded_model = GPT.load(tmp.name, device=config['device'])
|
|||
|
|
print("Модель успешно загружена")
|
|||
|
|
|
|||
|
|
# Проверка работы загруженной модели
|
|||
|
|
test_output = loaded_model(input_seq)
|
|||
|
|
print(f"Тест загруженной модели - выходная форма: {test_output.shape}")
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
main()
|