Files
simple-llm/example/example_gpt.py
Sergey Penkovsky ae87faddc2 feat: добавление реализации модели GPT
Основные изменения:
- Реализован основной класс GPT в simple_llm/transformer/gpt.py:
  * Токенные и позиционные эмбеддинги
  * Многоголовое внимание
  * Полносвязные слои
  * Нормализация слоев
  * Поддержка dropout

- Добавлен пример использования в example/example_gpt.py:
  * Инициализация модели
  * Генерация текста
  * Сохранение/загрузка модели

- Написаны тесты:
  * Базовый функционал модели
  * Операции сохранения/загрузки
  * Проверка размерностей ввода/вывода

- Добавлена документация на русском:
  * Обзор архитектуры
  * Процесс обучения
  * Примеры использования

- Обновлен README.md с информацией о GPT
2025-07-22 06:26:23 +03:00

72 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Пример использования GPT модели из simple_llm
1. Инициализация модели
2. Генерация текста
3. Сохранение/загрузка модели
"""
import torch
from simple_llm.transformer.gpt import GPT
def main():
# Конфигурация модели
config = {
'vocab_size': 10000, # Размер словаря
'max_seq_len': 256, # Макс. длина последовательности
'emb_size': 512, # Размерность эмбеддингов
'num_heads': 8, # Количество голов внимания
'head_size': 64, # Размер каждой головы внимания
'num_layers': 6, # Количество слоев декодера
'dropout': 0.1, # Dropout
'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}
# 1. Инициализация модели
print("Инициализация GPT модели...")
model = GPT(**config)
print(f"Модель создана на устройстве: {config['device']}")
print(f"Количество параметров: {sum(p.numel() for p in model.parameters()):,}")
# 2. Пример генерации с токенизатором
try:
from simple_llm.tokenizer.simple_bpe import SimpleBPE
print("\nИнициализация токенизатора...")
tokenizer = SimpleBPE()
text = "Пример текста для генерации"
print(f"Исходный текст: '{text}'")
input_ids = tokenizer.encode(text)
print(f"Токенизированный ввод: {input_ids}")
input_seq = torch.tensor([input_ids], device=config['device'])
generated = model.generate(input_seq, max_new_tokens=20)
decoded_text = tokenizer.decode(generated[0].tolist())
print(f"\nСгенерированный текст: '{decoded_text}'")
except ImportError:
print("\nТокенизатор не найден, используется числовая генерация...")
input_seq = torch.randint(0, config['vocab_size'], (1, 10)).to(config['device'])
print(f"Числовой ввод: {input_seq.tolist()[0]}")
generated = model.generate(input_seq, max_new_tokens=20)
print(f"Числовой вывод: {generated.tolist()[0]}")
# 3. Сохранение и загрузка модели
print("\nТест сохранения/загрузки...")
import tempfile
with tempfile.NamedTemporaryFile() as tmp:
model.save(tmp.name)
print(f"Модель сохранена во временный файл: {tmp.name}")
loaded_model = GPT.load(tmp.name, device=config['device'])
print("Модель успешно загружена")
# Проверка работы загруженной модели
test_output = loaded_model(input_seq)
print(f"Тест загруженной модели - выходная форма: {test_output.shape}")
if __name__ == "__main__":
main()