mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Обновление метода generate в GPT
Основные изменения: 1. Добавлена поддержка различных стратегий генерации: - Жадный поиск (do_sample=False) - Вероятностное сэмплирование (do_sample=True) - Top-k сэмплирование (top_k параметр) - Nucleus (top-p) сэмплирование (top_p параметр) - Температурное сэмплирование (temperature параметр) 2. Добавлена валидация параметров: - Проверка temperature > 0 - Проверка top_k > 0 - Проверка top_p в диапазоне (0, 1] - Запрет одновременного использования top_k и top_p 3. Улучшена документация: - Подробное описание всех параметров - Примеры использования - Примечания о детерминированности - Описание исключений 4. Оптимизация кода: - Эффективное обрезание последовательности - Оптимизированные операции с тензорами - Четкое разделение логики для разных режимов
This commit is contained in:
@@ -1,24 +1,66 @@
|
||||
"""
|
||||
Пример использования GPT модели из simple_llm
|
||||
|
||||
1. Инициализация модели
|
||||
2. Генерация текста
|
||||
3. Сохранение/загрузка модели
|
||||
"""
|
||||
|
||||
import torch
|
||||
import os
|
||||
from simple_llm.transformer.gpt import GPT
|
||||
|
||||
def use_numeric_generation(config, model):
|
||||
"""Функция для числовой генерации"""
|
||||
input_seq = torch.randint(0, config['vocab_size'], (1, 10)).to(config['device'])
|
||||
print(f"\nЧисловой ввод: {input_seq.tolist()[0]}")
|
||||
|
||||
print("\n=== Режимы генерации ===")
|
||||
|
||||
# 1. Жадная генерация
|
||||
greedy_output = model.generate(input_seq.clone(),
|
||||
max_new_tokens=20,
|
||||
do_sample=False)
|
||||
print("\n1. Жадная генерация (детерминированная):")
|
||||
print(greedy_output.tolist()[0])
|
||||
|
||||
# 2. Сэмплирование с температурой
|
||||
torch.manual_seed(42)
|
||||
temp_output = model.generate(input_seq.clone(),
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
temperature=0.7)
|
||||
print("\n2. Сэмплирование (температура=0.7):")
|
||||
print(temp_output.tolist()[0])
|
||||
|
||||
# 3. Top-k сэмплирование
|
||||
torch.manual_seed(42)
|
||||
topk_output = model.generate(input_seq.clone(),
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
top_k=50)
|
||||
print("\n3. Top-k сэмплирование (k=50):")
|
||||
print(topk_output.tolist()[0])
|
||||
|
||||
# 4. Nucleus (top-p) сэмплирование
|
||||
try:
|
||||
torch.manual_seed(42)
|
||||
topp_output = model.generate(input_seq.clone(),
|
||||
max_new_tokens=20,
|
||||
do_sample=True,
|
||||
top_p=0.9)
|
||||
print("\n4. Nucleus сэмплирование (p=0.9):")
|
||||
print(topp_output.tolist()[0])
|
||||
except Exception as e:
|
||||
print(f"\nОшибка при nucleus сэмплировании: {str(e)}")
|
||||
print("Пропускаем этот режим генерации")
|
||||
|
||||
def main():
|
||||
# Конфигурация модели
|
||||
config = {
|
||||
'vocab_size': 10000, # Размер словаря
|
||||
'max_seq_len': 256, # Макс. длина последовательности
|
||||
'emb_size': 512, # Размерность эмбеддингов
|
||||
'num_heads': 8, # Количество голов внимания
|
||||
'head_size': 64, # Размер каждой головы внимания
|
||||
'num_layers': 6, # Количество слоев декодера
|
||||
'dropout': 0.1, # Dropout
|
||||
'vocab_size': 10000,
|
||||
'max_seq_len': 256,
|
||||
'emb_size': 512,
|
||||
'num_heads': 8,
|
||||
'head_size': 64,
|
||||
'num_layers': 6,
|
||||
'dropout': 0.1,
|
||||
'device': 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
}
|
||||
|
||||
@@ -28,30 +70,9 @@ def main():
|
||||
print(f"Модель создана на устройстве: {config['device']}")
|
||||
print(f"Количество параметров: {sum(p.numel() for p in model.parameters()):,}")
|
||||
|
||||
# 2. Пример генерации с токенизатором
|
||||
try:
|
||||
from simple_llm.tokenizer.simple_bpe import SimpleBPE
|
||||
print("\nИнициализация токенизатора...")
|
||||
tokenizer = SimpleBPE()
|
||||
|
||||
text = "Пример текста для генерации"
|
||||
print(f"Исходный текст: '{text}'")
|
||||
|
||||
input_ids = tokenizer.encode(text)
|
||||
print(f"Токенизированный ввод: {input_ids}")
|
||||
|
||||
input_seq = torch.tensor([input_ids], device=config['device'])
|
||||
generated = model.generate(input_seq, max_new_tokens=20)
|
||||
|
||||
decoded_text = tokenizer.decode(generated[0].tolist())
|
||||
print(f"\nСгенерированный текст: '{decoded_text}'")
|
||||
except ImportError:
|
||||
print("\nТокенизатор не найден, используется числовая генерация...")
|
||||
input_seq = torch.randint(0, config['vocab_size'], (1, 10)).to(config['device'])
|
||||
print(f"Числовой ввод: {input_seq.tolist()[0]}")
|
||||
|
||||
generated = model.generate(input_seq, max_new_tokens=20)
|
||||
print(f"Числовой вывод: {generated.tolist()[0]}")
|
||||
# 2. Пример генерации
|
||||
print("\nИспользуется числовая генерация...")
|
||||
use_numeric_generation(config, model)
|
||||
|
||||
# 3. Сохранение и загрузка модели
|
||||
print("\nТест сохранения/загрузки...")
|
||||
@@ -63,9 +84,8 @@ def main():
|
||||
loaded_model = GPT.load(tmp.name, device=config['device'])
|
||||
print("Модель успешно загружена")
|
||||
|
||||
# Проверка работы загруженной модели
|
||||
test_output = loaded_model(input_seq)
|
||||
test_output = loaded_model(torch.randint(0, config['vocab_size'], (1, 5)).to(config['device']))
|
||||
print(f"Тест загруженной модели - выходная форма: {test_output.shape}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
Reference in New Issue
Block a user