Обновление метода generate в GPT

Основные изменения:
1. Добавлена поддержка различных стратегий генерации:
   - Жадный поиск (do_sample=False)
   - Вероятностное сэмплирование (do_sample=True)
   - Top-k сэмплирование (top_k параметр)
   - Nucleus (top-p) сэмплирование (top_p параметр)
   - Температурное сэмплирование (temperature параметр)

2. Добавлена валидация параметров:
   - Проверка temperature > 0
   - Проверка top_k > 0
   - Проверка top_p в диапазоне (0, 1]
   - Запрет одновременного использования top_k и top_p

3. Улучшена документация:
   - Подробное описание всех параметров
   - Примеры использования
   - Примечания о детерминированности
   - Описание исключений

4. Оптимизация кода:
   - Эффективное обрезание последовательности
   - Оптимизированные операции с тензорами
   - Четкое разделение логики для разных режимов
This commit is contained in:
Sergey Penkovsky
2025-07-22 10:53:57 +03:00
parent ae87faddc2
commit 5765eb3bd3
5 changed files with 441 additions and 219 deletions

View File

@@ -1,24 +1,66 @@
"""
Пример использования GPT модели из simple_llm
1. Инициализация модели
2. Генерация текста
3. Сохранение/загрузка модели
"""
import torch
import os
from simple_llm.transformer.gpt import GPT
def use_numeric_generation(config, model):
"""Функция для числовой генерации"""
input_seq = torch.randint(0, config['vocab_size'], (1, 10)).to(config['device'])
print(f"\nЧисловой ввод: {input_seq.tolist()[0]}")
print("\n=== Режимы генерации ===")
# 1. Жадная генерация
greedy_output = model.generate(input_seq.clone(),
max_new_tokens=20,
do_sample=False)
print("\n1. Жадная генерация (детерминированная):")
print(greedy_output.tolist()[0])
# 2. Сэмплирование с температурой
torch.manual_seed(42)
temp_output = model.generate(input_seq.clone(),
max_new_tokens=20,
do_sample=True,
temperature=0.7)
print("\n2. Сэмплирование (температура=0.7):")
print(temp_output.tolist()[0])
# 3. Top-k сэмплирование
torch.manual_seed(42)
topk_output = model.generate(input_seq.clone(),
max_new_tokens=20,
do_sample=True,
top_k=50)
print("\n3. Top-k сэмплирование (k=50):")
print(topk_output.tolist()[0])
# 4. Nucleus (top-p) сэмплирование
try:
torch.manual_seed(42)
topp_output = model.generate(input_seq.clone(),
max_new_tokens=20,
do_sample=True,
top_p=0.9)
print("\n4. Nucleus сэмплирование (p=0.9):")
print(topp_output.tolist()[0])
except Exception as e:
print(f"\nОшибка при nucleus сэмплировании: {str(e)}")
print("Пропускаем этот режим генерации")
def main():
# Конфигурация модели
config = {
'vocab_size': 10000, # Размер словаря
'max_seq_len': 256, # Макс. длина последовательности
'emb_size': 512, # Размерность эмбеддингов
'num_heads': 8, # Количество голов внимания
'head_size': 64, # Размер каждой головы внимания
'num_layers': 6, # Количество слоев декодера
'dropout': 0.1, # Dropout
'vocab_size': 10000,
'max_seq_len': 256,
'emb_size': 512,
'num_heads': 8,
'head_size': 64,
'num_layers': 6,
'dropout': 0.1,
'device': 'cuda' if torch.cuda.is_available() else 'cpu'
}
@@ -28,30 +70,9 @@ def main():
print(f"Модель создана на устройстве: {config['device']}")
print(f"Количество параметров: {sum(p.numel() for p in model.parameters()):,}")
# 2. Пример генерации с токенизатором
try:
from simple_llm.tokenizer.simple_bpe import SimpleBPE
print("\nИнициализация токенизатора...")
tokenizer = SimpleBPE()
text = "Пример текста для генерации"
print(f"Исходный текст: '{text}'")
input_ids = tokenizer.encode(text)
print(f"Токенизированный ввод: {input_ids}")
input_seq = torch.tensor([input_ids], device=config['device'])
generated = model.generate(input_seq, max_new_tokens=20)
decoded_text = tokenizer.decode(generated[0].tolist())
print(f"\nСгенерированный текст: '{decoded_text}'")
except ImportError:
print("\nТокенизатор не найден, используется числовая генерация...")
input_seq = torch.randint(0, config['vocab_size'], (1, 10)).to(config['device'])
print(f"Числовой ввод: {input_seq.tolist()[0]}")
generated = model.generate(input_seq, max_new_tokens=20)
print(f"Числовой вывод: {generated.tolist()[0]}")
# 2. Пример генерации
print("\nИспользуется числовая генерация...")
use_numeric_generation(config, model)
# 3. Сохранение и загрузка модели
print("\nТест сохранения/загрузки...")
@@ -63,9 +84,8 @@ def main():
loaded_model = GPT.load(tmp.name, device=config['device'])
print("Модель успешно загружена")
# Проверка работы загруженной модели
test_output = loaded_model(input_seq)
test_output = loaded_model(torch.randint(0, config['vocab_size'], (1, 5)).to(config['device']))
print(f"Тест загруженной модели - выходная форма: {test_output.shape}")
if __name__ == "__main__":
main()
main()