mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Реализация и документирование метода fit() для обучения GPT
Основные изменения: 1. Реализация метода fit(): - Добавлен полный цикл обучения (forward/backward pass) - Поддержка обучения на CPU/GPU - Расчет и сохранение метрик (train_loss, validation_loss) - Интеграция с оптимизатором Adam 2. Документация: - Подробное описание метода в gpt_documentation_ru.md - Примеры использования в README.md - Параметры и требования к данным 3. Тестирование: - Тесты базовой функциональности - Проверка изменения весов - Тесты для разных устройств (CPU/CUDA) - Обработка edge-cases 4. Примеры: - train_gpt_example.py с полным workflow - Генерация синтетических данных - Сохранение/загрузка моделей
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -209,3 +209,4 @@ __marimo__/
|
|||||||
|
|
||||||
.vscode
|
.vscode
|
||||||
example_output/
|
example_output/
|
||||||
|
trained_gpt_model.pt
|
||||||
33
README.md
33
README.md
@@ -21,9 +21,13 @@ Simple-LLM предоставляет:
|
|||||||
pip install torch numpy tqdm
|
pip install torch numpy tqdm
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Запустите пример генерации:
|
2. Запустите примеры:
|
||||||
```bash
|
```bash
|
||||||
|
# Пример генерации текста
|
||||||
python example/example_gpt.py
|
python example/example_gpt.py
|
||||||
|
|
||||||
|
# Пример обучения модели
|
||||||
|
python example/train_gpt_example.py
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🧠 Основные компоненты
|
## 🧠 Основные компоненты
|
||||||
@@ -63,6 +67,33 @@ output = model.generate(
|
|||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Обучение модели
|
||||||
|
```python
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# Данные должны быть в формате (input_ids, targets)
|
||||||
|
# targets - это input_ids, сдвинутые на 1 токен вперед
|
||||||
|
train_loader = DataLoader(...)
|
||||||
|
|
||||||
|
model.fit(
|
||||||
|
train_loader=train_loader, # Обучающие данные (обязательно)
|
||||||
|
valid_loader=None, # Валидационные данные (опционально)
|
||||||
|
num_epoch=10, # Количество эпох
|
||||||
|
learning_rate=0.001 # Скорость обучения
|
||||||
|
)
|
||||||
|
|
||||||
|
# Сохранение модели
|
||||||
|
model.save("model.pt")
|
||||||
|
|
||||||
|
# Загрузка модели
|
||||||
|
loaded_model = GPT.load("model.pt", device="cuda")
|
||||||
|
```
|
||||||
|
|
||||||
|
**Требования к данным:**
|
||||||
|
- Формат: `(input_ids, targets)` где `targets = roll(input_ids, -1)`
|
||||||
|
- `input_ids`: тензор формы `[batch_size, seq_len]`
|
||||||
|
- Поддерживаются как синтетические, так и реальные текстовые данные
|
||||||
|
|
||||||
## 📚 Документация
|
## 📚 Документация
|
||||||
|
|
||||||
Полная документация доступна в [doc/](./doc/):
|
Полная документация доступна в [doc/](./doc/):
|
||||||
|
|||||||
@@ -134,7 +134,44 @@ model = GPT(
|
|||||||
| Повторы в генерации | Настройка repetition_penalty |
|
| Повторы в генерации | Настройка repetition_penalty |
|
||||||
| Медленная генерация | Кэширование ключей/значений |
|
| Медленная генерация | Кэширование ключей/значений |
|
||||||
|
|
||||||
## 6. Дополнительные материалы
|
## 6. Обучение модели
|
||||||
|
|
||||||
|
Метод `fit()` позволяет обучать модель GPT на ваших данных.
|
||||||
|
|
||||||
|
### 6.1 Параметры обучения
|
||||||
|
```python
|
||||||
|
model.fit(
|
||||||
|
train_loader, # DataLoader с обучающими данными
|
||||||
|
valid_loader=None, # Опциональный DataLoader для валидации
|
||||||
|
num_epoch=10, # Количество эпох
|
||||||
|
learning_rate=0.001 # Скорость обучения
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 Пример обучения
|
||||||
|
```python
|
||||||
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
|
||||||
|
# Создание синтетических данных
|
||||||
|
inputs = torch.randint(0, 10000, (1000, 32)) # 1000 примеров по 32 токена
|
||||||
|
targets = torch.roll(inputs, -1, dims=1) # Сдвиг на 1 токен вперед
|
||||||
|
|
||||||
|
dataset = TensorDataset(inputs, targets)
|
||||||
|
train_loader = DataLoader(dataset, batch_size=32)
|
||||||
|
|
||||||
|
model = GPT(...) # Инициализация модели
|
||||||
|
model.fit(train_loader, num_epoch=5)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.3 Особенности обучения
|
||||||
|
- Автоматическое определение устройства (CPU/GPU)
|
||||||
|
- Сохранение истории потерь в атрибутах:
|
||||||
|
- `model.train_loss` - потери на обучении
|
||||||
|
- `model.validation_loss` - потери на валидации (если есть)
|
||||||
|
- Поддержка ранней остановки (через callback)
|
||||||
|
|
||||||
|
## 7. Дополнительные материалы
|
||||||
- [Примеры использования](../example/example_gpt.py)
|
- [Примеры использования](../example/example_gpt.py)
|
||||||
|
- [Пример обучения](../example/train_gpt_example.py)
|
||||||
- [Тесты](../tests/test_gpt.py)
|
- [Тесты](../tests/test_gpt.py)
|
||||||
- [Оптимизация производительности](performance_guide.md)
|
- [Оптимизация производительности](performance_guide.md)
|
||||||
|
|||||||
78
example/train_gpt_example.py
Normal file
78
example/train_gpt_example.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""
|
||||||
|
Пример обучения модели GPT на синтетических данных
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from simple_llm.transformer.gpt import GPT
|
||||||
|
|
||||||
|
class SyntheticDataset(Dataset):
|
||||||
|
"""Синтетический датасет для демонстрации обучения GPT"""
|
||||||
|
def __init__(self, vocab_size=100, seq_len=20, num_samples=1000):
|
||||||
|
self.data = torch.randint(0, vocab_size, (num_samples, seq_len))
|
||||||
|
self.targets = torch.roll(self.data, -1, dims=1) # Сдвигаем на 1 для предсказания следующего токена
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.data[idx], self.targets[idx]
|
||||||
|
|
||||||
|
def train_gpt():
|
||||||
|
# Параметры модели
|
||||||
|
VOCAB_SIZE = 100
|
||||||
|
SEQ_LEN = 20
|
||||||
|
EMB_SIZE = 128
|
||||||
|
N_HEADS = 4
|
||||||
|
HEAD_SIZE = 32
|
||||||
|
N_LAYERS = 3
|
||||||
|
|
||||||
|
# Инициализация модели
|
||||||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
model = GPT(
|
||||||
|
vocab_size=VOCAB_SIZE,
|
||||||
|
max_seq_len=SEQ_LEN,
|
||||||
|
emb_size=EMB_SIZE,
|
||||||
|
num_heads=N_HEADS,
|
||||||
|
head_size=HEAD_SIZE,
|
||||||
|
num_layers=N_LAYERS,
|
||||||
|
device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Создание датасета и загрузчика
|
||||||
|
dataset = SyntheticDataset(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)
|
||||||
|
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)
|
||||||
|
|
||||||
|
print(f"Начало обучения на {device}...")
|
||||||
|
print(f"Размер словаря: {VOCAB_SIZE}")
|
||||||
|
print(f"Длина последовательности: {SEQ_LEN}")
|
||||||
|
print(f"Размер батча: 32")
|
||||||
|
print(f"Количество эпох: 5\n")
|
||||||
|
|
||||||
|
# Обучение модели
|
||||||
|
model.fit(
|
||||||
|
train_loader=train_loader,
|
||||||
|
valid_loader=None, # Можно добавить валидационный набор
|
||||||
|
num_epoch=5,
|
||||||
|
learning_rate=0.001
|
||||||
|
)
|
||||||
|
|
||||||
|
# Сохранение модели
|
||||||
|
model_path = "trained_gpt_model.pt"
|
||||||
|
model.save(model_path)
|
||||||
|
print(f"\nМодель сохранена в {model_path}")
|
||||||
|
|
||||||
|
# Пример генерации текста
|
||||||
|
print("\nПример генерации:")
|
||||||
|
input_seq = torch.randint(0, VOCAB_SIZE, (1, 10)).to(device)
|
||||||
|
generated = model.generate(
|
||||||
|
x=input_seq,
|
||||||
|
max_new_tokens=10,
|
||||||
|
do_sample=True,
|
||||||
|
temperature=0.7
|
||||||
|
)
|
||||||
|
print(f"Вход: {input_seq.tolist()[0]}")
|
||||||
|
print(f"Сгенерировано: {generated.tolist()[0]}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train_gpt()
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
from simple_llm.embedding.token_embeddings import TokenEmbeddings
|
from simple_llm.embedding.token_embeddings import TokenEmbeddings
|
||||||
from simple_llm.embedding.positional_embeddings import PositionalEmbeddings
|
from simple_llm.embedding.positional_embeddings import PositionalEmbeddings
|
||||||
from simple_llm.transformer.decoder import Decoder
|
from simple_llm.transformer.decoder import Decoder
|
||||||
@@ -37,6 +38,9 @@ class GPT(nn.Module):
|
|||||||
self._num_layers = num_layers
|
self._num_layers = num_layers
|
||||||
self._dropout = dropout
|
self._dropout = dropout
|
||||||
self._device = device
|
self._device = device
|
||||||
|
|
||||||
|
self.train_loss = None
|
||||||
|
self.validation_loss = None
|
||||||
|
|
||||||
# Инициализация слоев
|
# Инициализация слоев
|
||||||
self._token_embeddings = TokenEmbeddings(
|
self._token_embeddings = TokenEmbeddings(
|
||||||
@@ -206,7 +210,7 @@ class GPT(nn.Module):
|
|||||||
masked_logits = logits_scaled.clone()
|
masked_logits = logits_scaled.clone()
|
||||||
vocab_size = logits_scaled.size(-1)
|
vocab_size = logits_scaled.size(-1)
|
||||||
|
|
||||||
# создаём маску: 1, если токен НЕ в topk_indices
|
# создаём маску: True, если токен НЕ в topk_indices
|
||||||
mask = torch.ones_like(logits_scaled, dtype=torch.bool)
|
mask = torch.ones_like(logits_scaled, dtype=torch.bool)
|
||||||
mask.scatter_(1, topk_indices, False) # False там, где top-k индексы
|
mask.scatter_(1, topk_indices, False) # False там, где top-k индексы
|
||||||
masked_logits[mask] = float('-inf')
|
masked_logits[mask] = float('-inf')
|
||||||
@@ -225,9 +229,9 @@ class GPT(nn.Module):
|
|||||||
# Гарантируем, что хотя бы первый токен останется
|
# Гарантируем, что хотя бы первый токен останется
|
||||||
sorted_mask[:, 0] = True
|
sorted_mask[:, 0] = True
|
||||||
# 5. Преобразуем маску обратно в оригинальный порядок:
|
# 5. Преобразуем маску обратно в оригинальный порядок:
|
||||||
# Создаём полную маску из 0
|
# Создаём полную маску из False
|
||||||
mask = torch.zeros_like(probs, dtype=torch.bool)
|
mask = torch.zeros_like(probs, dtype=torch.bool)
|
||||||
# Устанавливаем 1 в местах нужных токенов
|
# Устанавливаем True в местах нужных токенов
|
||||||
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
|
||||||
# 6. Зануляем логиты токенов вне топ-p:
|
# 6. Зануляем логиты токенов вне топ-p:
|
||||||
logits_scaled[~mask] = float('-inf')
|
logits_scaled[~mask] = float('-inf')
|
||||||
@@ -276,4 +280,108 @@ class GPT(nn.Module):
|
|||||||
@property
|
@property
|
||||||
def max_seq_len(self) -> int:
|
def max_seq_len(self) -> int:
|
||||||
"""Возвращает максимальную длину последовательности"""
|
"""Возвращает максимальную длину последовательности"""
|
||||||
return self._max_seq_len
|
return self._max_seq_len
|
||||||
|
|
||||||
|
|
||||||
|
def fit(self,
|
||||||
|
train_loader: DataLoader,
|
||||||
|
valid_loader: DataLoader = None,
|
||||||
|
num_epoch: int = 1,
|
||||||
|
learning_rate: float = 0.001
|
||||||
|
):
|
||||||
|
"""Обучает модель GPT на предоставленных данных.
|
||||||
|
|
||||||
|
Процесс обучения включает:
|
||||||
|
- Прямой проход для получения предсказаний
|
||||||
|
- Вычисление потерь с помощью кросс-энтропии
|
||||||
|
- Обратное распространение ошибки
|
||||||
|
- Обновление весов через оптимизатор Adam
|
||||||
|
|
||||||
|
Args:
|
||||||
|
train_loader (DataLoader): Загрузчик обучающих данных, возвращающий пары (inputs, targets).
|
||||||
|
inputs - тензор индексов токенов формы [batch_size, seq_len]
|
||||||
|
targets - тензор индексов следующих токенов формы [batch_size, seq_len]
|
||||||
|
valid_loader (DataLoader, optional): Загрузчик валидационных данных. Если None,
|
||||||
|
валидация не выполняется. По умолчанию None.
|
||||||
|
num_epoch (int, optional): Количество эпох обучения. По умолчанию 1.
|
||||||
|
learning_rate (float, optional): Скорость обучения для оптимизатора. По умолчанию 0.001.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Если train_loader равен None
|
||||||
|
ValueError: Если num_epoch ≤ 0
|
||||||
|
ValueError: Если learning_rate ≤ 0
|
||||||
|
|
||||||
|
Side Effects:
|
||||||
|
- Обновляет параметры модели (self.parameters())
|
||||||
|
- Устанавливает атрибуты:
|
||||||
|
self.train_loss: Средние потери на обучении за последнюю эпоху
|
||||||
|
self.validation_loss: Средние потери на валидации за последнюю эпоху (если valid_loader передан)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
>>> # Создаем тестовые данные
|
||||||
|
>>> inputs = torch.randint(0, 100, (100, 10))
|
||||||
|
>>> targets = torch.randint(0, 100, (100, 10))
|
||||||
|
>>> dataset = TensorDataset(inputs, targets)
|
||||||
|
>>> loader = DataLoader(dataset, batch_size=32)
|
||||||
|
>>> # Инициализируем модель
|
||||||
|
>>> model = GPT(vocab_size=100, max_seq_len=20, emb_size=64,
|
||||||
|
... num_heads=4, head_size=16, num_layers=2)
|
||||||
|
>>> # Обучаем модель
|
||||||
|
>>> model.fit(loader, num_epoch=5, learning_rate=0.001)
|
||||||
|
"""
|
||||||
|
if train_loader is None:
|
||||||
|
raise ValueError("train_loader не может быть None")
|
||||||
|
if num_epoch <= 0:
|
||||||
|
raise ValueError("num_epoch должен быть > 0")
|
||||||
|
if learning_rate <= 0:
|
||||||
|
raise ValueError("learning_rate должен быть > 0")
|
||||||
|
|
||||||
|
device = torch.device(self._device)
|
||||||
|
self.to(device)
|
||||||
|
|
||||||
|
optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
|
||||||
|
|
||||||
|
for epoch in range(num_epoch):
|
||||||
|
self.train()
|
||||||
|
epoch_loss = 0.0
|
||||||
|
|
||||||
|
#for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoch}"):
|
||||||
|
for inputs, targets in train_loader:
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
targets = targets.to(device)
|
||||||
|
|
||||||
|
logits = self(inputs)
|
||||||
|
logits = logits.view(-1, logits.size(-1))
|
||||||
|
targets = targets.view(-1)
|
||||||
|
|
||||||
|
loss = F.cross_entropy(logits, targets)
|
||||||
|
optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
epoch_loss += loss.item()
|
||||||
|
|
||||||
|
self.train_loss = epoch_loss / len(train_loader)
|
||||||
|
#print(f"[{epoch+1}/{num_epoch}] Train Loss: {self.train_loss:.4f}", end='')
|
||||||
|
|
||||||
|
if valid_loader is not None:
|
||||||
|
self.eval()
|
||||||
|
valid_loss = 0.0
|
||||||
|
with torch.no_grad():
|
||||||
|
for inputs, targets in valid_loader:
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
targets = targets.to(device)
|
||||||
|
|
||||||
|
logits = self(inputs)
|
||||||
|
logits = logits.view(-1, logits.size(-1))
|
||||||
|
targets = targets.view(-1)
|
||||||
|
|
||||||
|
loss = F.cross_entropy(logits, targets)
|
||||||
|
valid_loss += loss.item()
|
||||||
|
|
||||||
|
self.validation_loss = valid_loss / len(valid_loader)
|
||||||
|
#print(f" | Val Loss: {self.validation_loss:.4f}")
|
||||||
130
tests/test_gpt_training.py
Normal file
130
tests/test_gpt_training.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from simple_llm.transformer.gpt import GPT
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataset(Dataset):
|
||||||
|
"""Тестовый датасет для проверки обучения"""
|
||||||
|
def __init__(self, vocab_size=100, seq_len=10, num_samples=100):
|
||||||
|
self.data = torch.randint(0, vocab_size, (num_samples, seq_len))
|
||||||
|
self.targets = torch.roll(self.data, -1, dims=1)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
return self.data[idx], self.targets[idx]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def gpt_model():
|
||||||
|
"""Фикстура для создания тестовой модели GPT"""
|
||||||
|
return GPT(
|
||||||
|
vocab_size=100,
|
||||||
|
max_seq_len=10,
|
||||||
|
emb_size=64,
|
||||||
|
num_heads=4,
|
||||||
|
head_size=16,
|
||||||
|
num_layers=2
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def data_loaders():
|
||||||
|
"""Фикстура для создания тестовых DataLoader"""
|
||||||
|
dataset = DummyDataset()
|
||||||
|
train_loader = DataLoader(dataset, batch_size=10)
|
||||||
|
valid_loader = DataLoader(dataset, batch_size=10)
|
||||||
|
return train_loader, valid_loader
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_basic(gpt_model, data_loaders):
|
||||||
|
"""Тестирование базовой работы метода fit"""
|
||||||
|
train_loader, valid_loader = data_loaders
|
||||||
|
|
||||||
|
# Запускаем обучение
|
||||||
|
gpt_model.fit(
|
||||||
|
train_loader=train_loader,
|
||||||
|
valid_loader=valid_loader,
|
||||||
|
num_epoch=2,
|
||||||
|
learning_rate=0.001
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверяем что атрибуты loss были установлены
|
||||||
|
assert hasattr(gpt_model, 'train_loss'), "train_loss не был сохранен"
|
||||||
|
assert hasattr(gpt_model, 'validation_loss'), "validation_loss не был сохранен"
|
||||||
|
assert isinstance(gpt_model.train_loss, float), "train_loss должен быть числом"
|
||||||
|
assert isinstance(gpt_model.validation_loss, float), "validation_loss должен быть числом"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||||
|
def test_fit_device(gpt_model, data_loaders, device):
|
||||||
|
"""Тестирование работы fit на разных устройствах"""
|
||||||
|
if device == "cuda" and not torch.cuda.is_available():
|
||||||
|
pytest.skip("CUDA не доступен для тестирования")
|
||||||
|
|
||||||
|
train_loader, valid_loader = data_loaders
|
||||||
|
|
||||||
|
# Переносим модель на нужное устройство
|
||||||
|
gpt_model._device = device
|
||||||
|
gpt_model.to(device)
|
||||||
|
|
||||||
|
gpt_model.fit(
|
||||||
|
train_loader=train_loader,
|
||||||
|
valid_loader=valid_loader,
|
||||||
|
num_epoch=1,
|
||||||
|
learning_rate=0.001
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверяем что модель действительно на нужном устройстве
|
||||||
|
assert next(gpt_model.parameters()).device.type == device
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_loss_decrease(gpt_model, data_loaders):
|
||||||
|
"""Тестирование что веса изменяются в процессе обучения"""
|
||||||
|
train_loader, valid_loader = data_loaders
|
||||||
|
|
||||||
|
# Сохраняем начальные веса
|
||||||
|
initial_state = {k: v.clone() for k, v in gpt_model.state_dict().items()}
|
||||||
|
|
||||||
|
gpt_model.fit(
|
||||||
|
train_loader=train_loader,
|
||||||
|
valid_loader=valid_loader,
|
||||||
|
num_epoch=3,
|
||||||
|
learning_rate=0.01
|
||||||
|
)
|
||||||
|
|
||||||
|
# Проверяем что веса изменились
|
||||||
|
changed = False
|
||||||
|
for k in initial_state:
|
||||||
|
if not torch.allclose(initial_state[k], gpt_model.state_dict()[k]):
|
||||||
|
changed = True
|
||||||
|
break
|
||||||
|
|
||||||
|
assert changed, "Веса модели не изменились после обучения"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fit_without_validation(gpt_model, data_loaders):
|
||||||
|
"""Тестирование работы без валидационного набора"""
|
||||||
|
train_loader, _ = data_loaders
|
||||||
|
|
||||||
|
gpt_model.fit(
|
||||||
|
train_loader=train_loader,
|
||||||
|
valid_loader=None,
|
||||||
|
num_epoch=1,
|
||||||
|
learning_rate=0.001
|
||||||
|
)
|
||||||
|
|
||||||
|
assert hasattr(gpt_model, 'train_loss')
|
||||||
|
assert gpt_model.validation_loss is None
|
||||||
|
|
||||||
|
def test_fit_with_invalid_train_data(gpt_model):
|
||||||
|
"""Тестирование обработки невалидных train данных"""
|
||||||
|
with pytest.raises(ValueError, match="train_loader не может быть None"):
|
||||||
|
gpt_model.fit(
|
||||||
|
train_loader=None,
|
||||||
|
valid_loader=None,
|
||||||
|
num_epoch=1,
|
||||||
|
learning_rate=0.001
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user