mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Обновление метода generate в GPT
Основные изменения: 1. Добавлена поддержка различных стратегий генерации: - Жадный поиск (do_sample=False) - Вероятностное сэмплирование (do_sample=True) - Top-k сэмплирование (top_k параметр) - Nucleus (top-p) сэмплирование (top_p параметр) - Температурное сэмплирование (temperature параметр) 2. Добавлена валидация параметров: - Проверка temperature > 0 - Проверка top_k > 0 - Проверка top_p в диапазоне (0, 1] - Запрет одновременного использования top_k и top_p 3. Улучшена документация: - Подробное описание всех параметров - Примеры использования - Примечания о детерминированности - Описание исключений 4. Оптимизация кода: - Эффективное обрезание последовательности - Оптимизированные операции с тензорами - Четкое разделение логики для разных режимов
This commit is contained in:
@@ -83,22 +83,105 @@ class GPT(nn.Module):
|
||||
|
||||
return self._linear(out) # [batch, seq_len, vocab_size]
|
||||
|
||||
def generate(self, x: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
|
||||
"""Авторегрессивная генерация текста
|
||||
def generate(self,
|
||||
x: torch.Tensor,
|
||||
max_new_tokens: int,
|
||||
do_sample: bool,
|
||||
temperature: float = 1.0,
|
||||
top_k: int = None,
|
||||
top_p: float = None
|
||||
) -> torch.Tensor:
|
||||
"""Авторегрессивная генерация текста.
|
||||
|
||||
Параметры:
|
||||
x: Входной тензор с индексами токенов формы [batch_size, seq_len],
|
||||
где batch_size - размер батча, seq_len - длина последовательности.
|
||||
max_new_tokens: Максимальное количество новых токенов для генерации.
|
||||
do_sample: Флаг выбора режима генерации:
|
||||
- True: вероятностное сэмплирование
|
||||
- False: жадный поиск (argmax)
|
||||
temperature: Параметр температуры для сэмплирования:
|
||||
- >1.0 - более случайные результаты
|
||||
- 1.0 - нейтральное значение
|
||||
- <1.0 - более предсказуемые результаты
|
||||
Должна быть > 0 (по умолчанию: 1.0)
|
||||
top_k: Если задан (и do_sample=True), используется top-k сэмплирование:
|
||||
- Выбираются только top_k самых вероятных токенов
|
||||
- Остальным токенам устанавливается вероятность 0
|
||||
- None: отключено (по умолчанию)
|
||||
top_p: Если задан (и do_sample=True), используется nucleus (top-p) сэмплирование:
|
||||
- Выбираются токены с кумулятивной вероятностью ≤ top_p
|
||||
- Гарантируется, что хотя бы один токен остаётся (даже если его вероятность > top_p)
|
||||
- None: отключено (по умолчанию)
|
||||
- Должен быть в диапазоне (0, 1]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Тензор с расширенной последовательностью токенов формы
|
||||
[batch_size, seq_len + max_new_tokens]
|
||||
|
||||
Исключения:
|
||||
ValueError: Если входная последовательность длиннее max_seq_len
|
||||
ValueError: Если temperature <= 0
|
||||
ValueError: Если одновременно заданы top_k и top_p
|
||||
ValueError: Если top_k задан и ≤ 0
|
||||
ValueError: Если top_p задан и не в диапазоне (0, 1]
|
||||
|
||||
Примеры:
|
||||
>>> # Жадная генерация
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||
>>>
|
||||
>>> # Вероятностная генерация с top-k
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, top_k=50)
|
||||
>>>
|
||||
>>> # Nucleus sampling (top-p)
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, top_p=0.9)
|
||||
>>>
|
||||
>>> # Комбинация температуры и top-k
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True,
|
||||
... temperature=0.7, top_k=50)
|
||||
|
||||
Примечания:
|
||||
1. Для детерминированных результатов в режиме сэмплирования
|
||||
зафиксируйте random seed (torch.manual_seed).
|
||||
2. Температура влияет только на режим сэмплирования (do_sample=True).
|
||||
3. Одновременное использование top_k и top_p запрещено.
|
||||
4. При do_sample=False параметры top_k, top_p и temperature игнорируются.
|
||||
|
||||
Args:
|
||||
x: Входной тензор с индексами токенов [batch_size, seq_len]
|
||||
max_new_tokens: Максимальное количество новых токенов для генерации
|
||||
|
||||
x (torch.Tensor): Входной тензор с индексами токенов формы [batch_size, seq_len],
|
||||
где batch_size - размер батча, seq_len - длина последовательности.
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Флаг выбора режима генерации:
|
||||
- True: вероятностное сэмплирование
|
||||
- False: жадный поиск (argmax)
|
||||
temperature (float): Параметр температуры для сэмплирования:
|
||||
- >1.0 - более случайные результаты
|
||||
- 1.0 - нейтральное значение
|
||||
- <1.0 - более предсказуемые результаты
|
||||
Должна быть > 0 (по умолчанию: 1.0)
|
||||
|
||||
Returns:
|
||||
Тензор с расширенной последовательностью токенов [batch_size, seq_len + max_new_tokens]
|
||||
|
||||
Алгоритм работы:
|
||||
1. На каждом шаге берется последний фрагмент последовательности (не длиннее max_seq_len)
|
||||
2. Вычисляются логиты для следующего токена
|
||||
3. Выбирается токен с максимальной вероятностью (жадный алгоритм)
|
||||
4. Токен добавляется к последовательности
|
||||
5. Процесс повторяется пока не сгенерируется max_new_tokens токенов
|
||||
torch.Tensor: Тензор с расширенной последовательностью токенов формы
|
||||
[batch_size, seq_len + max_new_tokens]
|
||||
|
||||
Raises:
|
||||
ValueError: Если входная последовательность длиннее max_seq_len
|
||||
ValueError: Если temperature <= 0
|
||||
|
||||
Examples:
|
||||
>>> # Жадная генерация
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
|
||||
>>>
|
||||
>>> # Вероятностная генерация с температурой
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, temperature=0.7)
|
||||
>>>
|
||||
>>> # Более случайная генерация
|
||||
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, temperature=1.5)
|
||||
|
||||
Note:
|
||||
Для детерминированных результатов в режиме сэмплирования
|
||||
зафиксируйте random seed (torch.manual_seed).
|
||||
Температура влияет только на режим сэмплирования (do_sample=True).
|
||||
"""
|
||||
for _ in range(max_new_tokens):
|
||||
# 1. Обрезаем вход, если последовательность слишком длинная
|
||||
@@ -110,14 +193,58 @@ class GPT(nn.Module):
|
||||
# 3. Берем логиты для последнего токена
|
||||
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
|
||||
|
||||
# Масштабируем логиты температурой
|
||||
if temperature > 0:
|
||||
logits_scaled = last_logits / temperature
|
||||
else:
|
||||
logits_scaled = last_logits
|
||||
|
||||
if do_sample == True and top_k != None:
|
||||
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
|
||||
|
||||
# # Заменим все НЕ top-k логиты на -inf
|
||||
masked_logits = logits_scaled.clone()
|
||||
vocab_size = logits_scaled.size(-1)
|
||||
|
||||
# создаём маску: 1, если токен НЕ в topk_indices
|
||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool)
|
||||
mask.scatter_(1, topk_indices, False) # False там, где top-k индексы
|
||||
masked_logits[mask] = float('-inf')
|
||||
|
||||
logits_scaled = masked_logits
|
||||
|
||||
if do_sample == True and top_p != None:
|
||||
# 1. Применим softmax, чтобы получить вероятности:
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
|
||||
# 2. Отсортируем токены по убыванию вероятностей:
|
||||
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
|
||||
# 3. Посчитаем кумулятивную сумму вероятностей:
|
||||
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
|
||||
# 4. Определим маску: оставить токены, пока сумма < top_p
|
||||
sorted_mask = (cum_probs <= top_p) # [B, vocab_size]
|
||||
# Гарантируем, что хотя бы первый токен останется
|
||||
sorted_mask[:, 0] = True
|
||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||
# Создаём полную маску из 0
|
||||
mask = torch.zeros_like(probs, dtype=torch.bool)
|
||||
# Устанавливаем 1 в местах нужных токенов
|
||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||
# 6. Зануляем логиты токенов вне топ-p:
|
||||
logits_scaled[~mask] = float('-inf')
|
||||
|
||||
# 4. Применяем Softmax
|
||||
probs = F.softmax(last_logits, dim=-1) # [batch_size, vocab_size]
|
||||
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
|
||||
|
||||
# 5. Выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
if do_sample == True:
|
||||
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
|
||||
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
|
||||
else:
|
||||
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
|
||||
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
|
||||
|
||||
# 6. Добавляем его к последовательности
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
|
||||
return x
|
||||
|
||||
def save(self, path):
|
||||
|
||||
Reference in New Issue
Block a user