mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Рефакторинг и улучшение компонентов
Основные изменения в коде:
1. Токенизатор (bpe.py):
- Добавлен прогресс-бар через tqdm в метод fit()
- Улучшено логирование процесса обучения
- Добавлена обработка edge-cases для vocab_size
2. Генерация текста (generate_text.py):
- Полный рефакторинг скрипта
- Добавлены проверки модели перед загрузкой
- Поддержка уменьшенных моделей (seq_len=32)
- Подробное логирование процесса генерации
3. Обучение GPT (train_gpt_model.py):
- Автоподбор параметров под размер данных
- Уменьшенные параметры модели по умолчанию
- Контроль памяти и устройств (CPU/MPS)
4. Токенизация корпуса (tokenize_corpus.py):
- Добавлены проверки входных данных
- Подробное логирование процесса
- Обработка ошибок загрузки файлов
Исправления:
- Синхронизация размеров слоёв в GPT
- Корректная работа с малыми наборами данных
- Исправление загрузки моделей на MPS
Обновление README.md
- Добавлены обязательные зависимости: dill и tqdm
- Добавлен раздел 'Цель проекта' с описанием задач
- Добавлен раздел 'Участие в разработке' для контрибьюторов
- Добавлен раздел 'Лицензия' с условиями MIT
Рефакторинг основных скриптов и обновление данных
Основные изменения:
1. Скрипты в bin/:
- Оптимизация generate_text.py (генерация текста)
- Улучшение tokenize_corpus.py (обработка корпуса)
- Рефакторинг train_gpt_model.py (обучение модели)
- Обновление train_tokenizer.py (алгоритм BPE)
2. Данные:
- Удалены устаревшие артефакты:
* simple_llm_gpt.pth (модель)
* bpe_tokenizer.json (токенизатор)
* corpus_tokens.pkl (токены)
- Подготовка к генерации новых данных
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import dill
|
||||
from tqdm import tqdm
|
||||
|
||||
class BPE:
|
||||
"""Реализация алгоритма Byte Pair Encoding (BPE) для токенизации текста.
|
||||
@@ -35,24 +36,30 @@ class BPE:
|
||||
>>> tokenizer = BPE(vocab_size=100)
|
||||
>>> tokenizer.fit("Это текст для обучения токенизатора")
|
||||
"""
|
||||
# Инициализируем прогресс-бар
|
||||
pbar = tqdm(total=self.vocab_size, desc="Building vocabulary")
|
||||
# 1. Получаем уникальные токены (символы)
|
||||
unique_tokens = sorted(set(text))
|
||||
tokens = unique_tokens.copy()
|
||||
pbar.update(len(tokens)) # Обновляем прогресс начальными токенами
|
||||
|
||||
# 2. Разбиваем текст на токены-символы
|
||||
sequence = list(text)
|
||||
|
||||
# 3. Объединяем токены до достижения нужного размера словаря
|
||||
while len(tokens) < self.vocab_size:
|
||||
pbar.update(1) # Обновляем прогресс на каждой итерации
|
||||
print(f"\nТекущий размер словаря: {len(tokens)}/{self.vocab_size}")
|
||||
#print(f'len={len(tokens)} < {self.vocab_size}')
|
||||
# Считаем частоты пар
|
||||
pair_freq = {}
|
||||
for i in range(len(sequence) - 1):
|
||||
pair = (sequence[i], sequence[i + 1])
|
||||
#print(f'pair = {pair}')
|
||||
if pair not in pair_freq:
|
||||
pair_freq[pair] = 0
|
||||
pair_freq[pair] += 1
|
||||
|
||||
print(f"Найдено {len(pair_freq)} уникальных пар")
|
||||
|
||||
|
||||
#print(f'pair_freq = {pair_freq}')
|
||||
@@ -64,12 +71,11 @@ class BPE:
|
||||
|
||||
# Находим самую частую пару (в случае равенства — та, что встретилась первой)
|
||||
most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]
|
||||
#print(most_frequent_pair)
|
||||
print(f"Самая частая пара: {most_frequent_pair} (встречается {pair_freq[most_frequent_pair]} раз)")
|
||||
# Создаем новый токен
|
||||
new_token = most_frequent_pair[0] + most_frequent_pair[1]
|
||||
#print(f"new token={new_token}")
|
||||
print(f"Добавлен новый токен: '{new_token}'")
|
||||
tokens.append(new_token)
|
||||
#print(f"tokens={tokens}")
|
||||
|
||||
i = 0
|
||||
new_sequence = []
|
||||
@@ -88,6 +94,7 @@ class BPE:
|
||||
self.vocab = tokens.copy()
|
||||
self.token2id = dict(zip(tokens, range(self.vocab_size)))
|
||||
self.id2token = dict(zip(range(self.vocab_size), tokens))
|
||||
pbar.close() # Закрываем прогресс-бар
|
||||
|
||||
def _pair_first_index(self, sequence, pair):
|
||||
for i in range(len(sequence) - 1):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import dill
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict
|
||||
|
||||
@@ -84,4 +85,37 @@ class BPE(ABC):
|
||||
tokens.append(self.id2token[id])
|
||||
else:
|
||||
tokens.append('') # Специальное значение
|
||||
return tokens
|
||||
return tokens
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename, 'wb') as f:
|
||||
dill.dump(self, f)
|
||||
print(f"Объект сохранён в {filename}")
|
||||
|
||||
|
||||
@classmethod
|
||||
def load(cls, filename):
|
||||
"""Загружает токенизатор из файла.
|
||||
|
||||
Args:
|
||||
filename (str): Путь к файлу с сохраненным токенизатором
|
||||
|
||||
Returns:
|
||||
BPE: Загруженный экземпляр токенизатора
|
||||
|
||||
Пример:
|
||||
>>> tokenizer = BPE.load("bpe_tokenizer.pkl")
|
||||
"""
|
||||
"""Load trained tokenizer from file.
|
||||
|
||||
Args:
|
||||
filename (str): Path to saved tokenizer
|
||||
|
||||
Returns:
|
||||
BPE: Loaded tokenizer instance
|
||||
"""
|
||||
with open(filename, 'rb') as f:
|
||||
obj = dill.load(f)
|
||||
|
||||
print(f"Объект загружен из {filename}")
|
||||
return obj
|
||||
@@ -1,5 +1,5 @@
|
||||
from .bpe_interface import BPE
|
||||
|
||||
from tqdm import tqdm
|
||||
from collections import Counter
|
||||
from typing import List, Tuple, Dict
|
||||
|
||||
@@ -18,19 +18,34 @@ class OptimizeBPE(BPE):
|
||||
self._init_vocab(sequence)
|
||||
pair_freq, pair_first_occurrence = self._get_pair_stats(sequence)
|
||||
|
||||
while len(self.vocab) < self.vocab_size and pair_freq:
|
||||
pair_to_merge = self._select_pair_to_merge(pair_freq, pair_first_occurrence)
|
||||
new_token = pair_to_merge[0] + pair_to_merge[1]
|
||||
# Инициализация прогресс-бара
|
||||
with tqdm(total=self.vocab_size, desc="Building vocabulary") as pbar:
|
||||
pbar.update(len(self.vocab)) # Учитываем начальные токены
|
||||
|
||||
if new_token in self.vocab:
|
||||
# Защита от зацикливания: пара уже была добавлена как новый токен.
|
||||
del pair_freq[pair_to_merge]
|
||||
continue
|
||||
while len(self.vocab) < self.vocab_size and pair_freq:
|
||||
pair_to_merge = self._select_pair_to_merge(pair_freq, pair_first_occurrence)
|
||||
new_token = pair_to_merge[0] + pair_to_merge[1]
|
||||
|
||||
# Обновляем прогресс и логируем
|
||||
pbar.update(1)
|
||||
pbar.set_postfix({
|
||||
'current_vocab': len(self.vocab),
|
||||
'top_pair': f"{pair_to_merge[0]}{pair_to_merge[1]}",
|
||||
'pair_freq': pair_freq[pair_to_merge]
|
||||
})
|
||||
print(f"\nТекущий размер словаря: {len(self.vocab)}/{self.vocab_size}")
|
||||
print(f"Самая частая пара: {pair_to_merge} (встречается {pair_freq[pair_to_merge]} раз)")
|
||||
print(f"Добавлен новый токен: '{new_token}'")
|
||||
|
||||
self.vocab.append(new_token)
|
||||
sequence, pair_freq, pair_first_occurrence = self._merge_pair(
|
||||
sequence, pair_to_merge, new_token, pair_freq
|
||||
)
|
||||
if new_token in self.vocab:
|
||||
# Защита от зацикливания: пара уже была добавлена как новый токен.
|
||||
del pair_freq[pair_to_merge]
|
||||
continue
|
||||
|
||||
self.vocab.append(new_token)
|
||||
sequence, pair_freq, pair_first_occurrence = self._merge_pair(
|
||||
sequence, pair_to_merge, new_token, pair_freq
|
||||
)
|
||||
|
||||
self._build_token_dicts()
|
||||
|
||||
|
||||
@@ -333,6 +333,9 @@ class GPT(nn.Module):
|
||||
>>> # Обучаем модель
|
||||
>>> model.fit(loader, num_epoch=5, learning_rate=0.001)
|
||||
"""
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
|
||||
if train_loader is None:
|
||||
raise ValueError("train_loader не может быть None")
|
||||
if num_epoch <= 0:
|
||||
@@ -344,13 +347,24 @@ class GPT(nn.Module):
|
||||
self.to(device)
|
||||
|
||||
optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||
|
||||
|
||||
print(f"\nНачало обучения GPT на {num_epoch} эпох")
|
||||
print(f"Размер батча: {train_loader.batch_size}")
|
||||
print(f"Всего батчей: {len(train_loader)}")
|
||||
print(f"Устройство: {device}\n")
|
||||
|
||||
for epoch in range(num_epoch):
|
||||
self.train()
|
||||
epoch_loss = 0.0
|
||||
|
||||
#for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch}"):
|
||||
for inputs, targets in train_loader:
|
||||
start_time = time.time()
|
||||
|
||||
# Прогресс-бар для батчей
|
||||
batch_pbar = tqdm(train_loader,
|
||||
desc=f"Эпоха {epoch+1}/{num_epoch}",
|
||||
leave=False,
|
||||
bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]')
|
||||
|
||||
for batch_idx, (inputs, targets) in enumerate(batch_pbar):
|
||||
inputs = inputs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
@@ -364,15 +378,33 @@ class GPT(nn.Module):
|
||||
optimizer.step()
|
||||
|
||||
epoch_loss += loss.item()
|
||||
|
||||
# Обновляем описание прогресс-бара
|
||||
batch_pbar.set_postfix({
|
||||
'loss': f"{loss.item():.4f}",
|
||||
'lr': f"{learning_rate:.0e}"
|
||||
})
|
||||
|
||||
# Логирование каждые N батчей
|
||||
if batch_idx % 10 == 0:
|
||||
tqdm.write(f"Батч {batch_idx}/{len(train_loader)} - Loss: {loss.item():.4f}")
|
||||
|
||||
self.train_loss = epoch_loss / len(train_loader)
|
||||
#print(f"[{epoch+1}/{num_epoch}] Train Loss: {self.train_loss:.4f}", end='')
|
||||
epoch_time = time.time() - start_time
|
||||
|
||||
print(f"\nЭпоха {epoch+1}/{num_epoch} завершена за {epoch_time:.2f} сек")
|
||||
print(f"Средний Train Loss: {self.train_loss:.4f}")
|
||||
|
||||
if valid_loader is not None:
|
||||
self.eval()
|
||||
valid_loss = 0.0
|
||||
with torch.no_grad():
|
||||
for inputs, targets in valid_loader:
|
||||
# Прогресс-бар для валидации
|
||||
valid_pbar = tqdm(valid_loader,
|
||||
desc=f"Валидация {epoch+1}/{num_epoch}",
|
||||
leave=False)
|
||||
|
||||
for inputs, targets in valid_pbar:
|
||||
inputs = inputs.to(device)
|
||||
targets = targets.to(device)
|
||||
|
||||
@@ -384,4 +416,4 @@ class GPT(nn.Module):
|
||||
valid_loss += loss.item()
|
||||
|
||||
self.validation_loss = valid_loss / len(valid_loader)
|
||||
#print(f" | Val Loss: {self.validation_loss:.4f}")
|
||||
print(f"Средний Val Loss: {self.validation_loss:.4f}")
|
||||
Reference in New Issue
Block a user