mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Реализация класса GetData
- Добавлен класс GetData для работы с последовательными данными - Реализован функционал: * Создание датасета из последовательности * Автоматическое формирование пар (input, target) * Поддержка CPU/GPU * Проверка корректности параметров - Добавлены тесты для проверки функционала - Создан пример использования в example/ - Добавлена документация с блок-схемой в doc/ - Обновлен README.md с информацией о новом классе
This commit is contained in:
12
README.md
12
README.md
@@ -28,6 +28,17 @@ python example/example_gpt.py
|
||||
|
||||
## 🧠 Основные компоненты
|
||||
|
||||
### Обработка данных
|
||||
```python
|
||||
from simple_llm.data.get_data import GetData
|
||||
|
||||
dataset = GetData(
|
||||
data=[1, 2, 3, 4, 5], # Входная последовательность
|
||||
seq_len=3, # Длина окна
|
||||
device="cuda" # Устройство (опционально)
|
||||
)
|
||||
```
|
||||
|
||||
### Модель GPT
|
||||
```python
|
||||
from simple_llm.transformer.gpt import GPT
|
||||
@@ -57,6 +68,7 @@ output = model.generate(
|
||||
Полная документация доступна в [doc/](./doc/):
|
||||
- [Архитектура GPT](./doc/gpt_documentation_ru.md)
|
||||
- [Алгоритм BPE](./doc/bpe_algorithm.md)
|
||||
- [Обработка последовательностей](./doc/get_data_documentation_ru.md)
|
||||
- [Примеры использования](./example/)
|
||||
|
||||
## 🛠 Тестирование
|
||||
|
||||
79
doc/get_data_documentation_ru.md
Normal file
79
doc/get_data_documentation_ru.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# Документация по классу GetData
|
||||
|
||||
## Назначение
|
||||
Класс `GetData` предназначен для создания датасетов из последовательных данных для обучения языковых моделей и других задач, работающих с последовательностями.
|
||||
|
||||
## Основные возможности
|
||||
- Преобразование последовательности данных в обучающие пары (input, target)
|
||||
- Поддержка различных типов данных (числа, токены)
|
||||
- Автоматический сдвиг целевой последовательности
|
||||
- Поддержка работы на CPU/GPU
|
||||
- Проверка корректности входных параметров
|
||||
|
||||
## Алгоритм работы
|
||||
1. Принимает на вход последовательность данных и длину окна
|
||||
2. Скользящим окном проходит по последовательности
|
||||
3. Для каждой позиции создает пару:
|
||||
- Входная последовательность: `data[pos:pos+seq_len]`
|
||||
- Целевая последовательность: `data[pos+1:pos+seq_len+1]` (сдвиг на 1 элемент)
|
||||
4. Преобразует данные в тензоры PyTorch
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
A[Начало] --> B[Проверка параметров]
|
||||
B -->|seq_len <= 0| C[Ошибка: отрицательная длина]
|
||||
B -->|seq_len >= len(data)| D[Ошибка: слишком длинное окно]
|
||||
B -->|Параметры верны| E[Инициализация датасета]
|
||||
E --> F[Для каждого индекса i от 0 до len(data)-seq_len-1]
|
||||
F --> G[Входной тензор: data[i:i+seq_len]]
|
||||
G --> H[Целевой тензор: data[i+1:i+seq_len+1]]
|
||||
H --> I[Преобразование в тензоры PyTorch]
|
||||
I --> J[Возврат пары тензоров]
|
||||
J --> F
|
||||
F -->|Все индексы обработаны| K[Конец]
|
||||
```
|
||||
|
||||
## Пример использования
|
||||
```python
|
||||
from simple_llm.data.get_data import GetData
|
||||
|
||||
data = list(range(10)) # Последовательность 0-9
|
||||
seq_len = 3
|
||||
dataset = GetData(data=data, seq_len=seq_len)
|
||||
|
||||
# Получение первого примера
|
||||
x, y = dataset[0]
|
||||
print(f"Вход: {x.tolist()} → Цель: {y.tolist()}")
|
||||
# Вывод: Вход: [0, 1, 2] → Цель: [1, 2, 3]
|
||||
```
|
||||
|
||||
## Параметры класса
|
||||
- `data` (list): Входная последовательность данных
|
||||
- `seq_len` (int): Длина окна последовательности
|
||||
- `device` (str): Устройство для тензоров ('cpu' или 'cuda')
|
||||
|
||||
## Методы
|
||||
- `__len__()`: Возвращает количество обучающих примеров
|
||||
- `__getitem__(idx)`: Возвращает пару тензоров по индексу
|
||||
|
||||
## Ошибки
|
||||
- `ValueError`: Если `seq_len` <= 0 или >= длины данных
|
||||
|
||||
## Применение
|
||||
1. Обучение языковых моделей
|
||||
2. Прогнозирование временных рядов
|
||||
3. Любые задачи, требующие работы с последовательностями
|
||||
|
||||
## Рекомендации
|
||||
- Для текстовых данных предварительно токенизируйте текст
|
||||
- Для больших датасетов используйте GPU (device='cuda')
|
||||
- Подбирайте seq_len в зависимости от задачи
|
||||
|
||||
## Пример с текстовыми данными
|
||||
```python
|
||||
text_tokens = [10, 20, 30, 40] # Токенизированный текст
|
||||
dataset = GetData(text_tokens, seq_len=2)
|
||||
x, y = dataset[1]
|
||||
print(f"Вход: {x.tolist()} → Цель: {y.tolist()}")
|
||||
# Вывод: Вход: [20, 30] → Цель: [30, 40]
|
||||
```
|
||||
57
example/example_get_data.py
Normal file
57
example/example_get_data.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""
|
||||
Пример использования класса GetData для работы с последовательными данными.
|
||||
|
||||
Этот пример показывает:
|
||||
1. Как создать датасет из последовательности чисел
|
||||
2. Как получить пары (вход, цель) для обучения
|
||||
3. Как работать с разными длинами последовательностей
|
||||
4. Как использовать GPU (если доступен)
|
||||
"""
|
||||
|
||||
from simple_llm.data.get_data import GetData
|
||||
import torch
|
||||
|
||||
def main():
|
||||
# 1. Простейший пример с последовательностью чисел
|
||||
print("\n=== Пример 1: Базовая последовательность ===")
|
||||
data = list(range(10)) # [0, 1, 2, ..., 9]
|
||||
seq_len = 3
|
||||
dataset = GetData(data=data, seq_len=seq_len)
|
||||
|
||||
print(f"Длина датасета: {len(dataset)}")
|
||||
for i in range(min(3, len(dataset))): # Покажем первые 3 примера
|
||||
x, y = dataset[i]
|
||||
print(f"Пример {i}:")
|
||||
print(f" Вход: {x.tolist()} → Цель: {y.tolist()}")
|
||||
|
||||
# 2. Пример с текстовыми данными (последовательность токенов)
|
||||
print("\n=== Пример 2: Токенизированный текст ===")
|
||||
text_tokens = [10, 20, 30, 40, 50, 60, 70] # Пример токенов
|
||||
text_seq_len = 2
|
||||
text_dataset = GetData(data=text_tokens, seq_len=text_seq_len)
|
||||
|
||||
print(f"Длина датасета: {len(text_dataset)}")
|
||||
for i in range(len(text_dataset)):
|
||||
x, y = text_dataset[i]
|
||||
print(f"Пример {i}: {x.tolist()} → {y.tolist()}")
|
||||
|
||||
# 3. Пример с использованием GPU (если доступен)
|
||||
print("\n=== Пример 3: Работа с GPU ===")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
print(f"Используемое устройство: {device}")
|
||||
|
||||
gpu_dataset = GetData(data=data, seq_len=seq_len, device=device)
|
||||
x, y = gpu_dataset[0]
|
||||
print(f"Пример на {device}:")
|
||||
print(f" Вход: {x.tolist()} (устройство: {x.device})")
|
||||
print(f" Цель: {y.tolist()} (устройство: {y.device})")
|
||||
|
||||
# 4. Пример обработки ошибок
|
||||
print("\n=== Пример 4: Обработка ошибок ===")
|
||||
try:
|
||||
GetData(data=[1, 2, 3], seq_len=4) # Слишком длинная последовательность
|
||||
except ValueError as e:
|
||||
print(f"Ошибка: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
simple_llm/data/__init__.py
Normal file
0
simple_llm/data/__init__.py
Normal file
77
simple_llm/data/get_data.py
Normal file
77
simple_llm/data/get_data.py
Normal file
@@ -0,0 +1,77 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
class GetData(Dataset):
|
||||
"""
|
||||
Класс для создания датасета последовательных данных для обучения языковых моделей.
|
||||
|
||||
Наследуется от torch.utils.data.Dataset и реализует:
|
||||
- Скользящее окно по последовательности данных
|
||||
- Автоматическое разделение на входные и целевые последовательности
|
||||
- Поддержку работы на CPU/GPU
|
||||
- Проверку корректности параметров
|
||||
|
||||
Args:
|
||||
data (List): Обучающая последовательность (список чисел или токенов)
|
||||
seq_len (int): Длина одной обучающей последовательности (в элементах).
|
||||
Должна быть положительной и меньше длины данных.
|
||||
device (str, optional): Устройство для тензоров ('cpu' или 'cuda'). По умолчанию 'cpu'.
|
||||
|
||||
Raises:
|
||||
ValueError: Если seq_len <= 0 или seq_len >= len(data)
|
||||
|
||||
Attributes:
|
||||
_data (List): Хранит входную последовательность
|
||||
_seq_len (int): Длина последовательности для обучения
|
||||
_device (str): Устройство для вычислений
|
||||
|
||||
Examples:
|
||||
>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||
>>> dataset = GetData(data, seq_len=3)
|
||||
>>> len(dataset)
|
||||
6
|
||||
>>> dataset[0]
|
||||
(tensor([1, 2, 3]), tensor([2, 3, 4]))
|
||||
|
||||
# Некорректные параметры
|
||||
>>> GetData(data=[1, 2, 3], seq_len=4) # Вызовет ValueError
|
||||
>>> GetData(data=[1, 2, 3], seq_len=-1) # Вызовет ValueError
|
||||
"""
|
||||
|
||||
def __init__(self, data: list, seq_len: int, device: str = "cpu") -> None:
|
||||
"""Инициализация датасета с последовательными данными."""
|
||||
if seq_len <= 0:
|
||||
raise ValueError(f"Sequence length must be positive, got {seq_len}")
|
||||
if seq_len >= len(data):
|
||||
raise ValueError(f"Sequence length {seq_len} must be less than data length {len(data)}")
|
||||
self._data = data
|
||||
self._seq_len = seq_len
|
||||
self._device = device
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Возвращает количество обучающих примеров в датасете.
|
||||
|
||||
Формула:
|
||||
N - seq_len - 1
|
||||
где N - длина всей последовательности
|
||||
|
||||
Returns:
|
||||
int: Количество доступных последовательностей
|
||||
"""
|
||||
return len(self._data) - self._seq_len - 1
|
||||
|
||||
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Возвращает один обучающий пример по индексу.
|
||||
|
||||
Args:
|
||||
idx (int): Позиция начала последовательности
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Пара (входная_последовательность, целевая_последовательность)
|
||||
где целевая последовательность сдвинута на 1 элемент вперед
|
||||
"""
|
||||
x = torch.tensor(self._data[idx:idx+self._seq_len]).to(self._device)
|
||||
y = torch.tensor(self._data[idx+1:idx+self._seq_len+1]).to(self._device)
|
||||
return (x, y)
|
||||
0
simple_llm/embedding/__init__.py
Normal file
0
simple_llm/embedding/__init__.py
Normal file
0
simple_llm/transformer/__init__.py
Normal file
0
simple_llm/transformer/__init__.py
Normal file
102
tests/test_get_data.py
Normal file
102
tests/test_get_data.py
Normal file
@@ -0,0 +1,102 @@
|
||||
import pytest
|
||||
import torch
|
||||
import sys
|
||||
import os
|
||||
|
||||
from simple_llm.data.get_data import GetData
|
||||
|
||||
class TestGetData:
|
||||
"""Набор тестов для проверки класса GetData"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data(self):
|
||||
"""Фикстура с тестовыми данными: последовательность чисел 0-99"""
|
||||
return list(range(100))
|
||||
|
||||
def test_initialization(self, sample_data):
|
||||
"""Тест корректности инициализации класса"""
|
||||
seq_len = 10
|
||||
dataset = GetData(data=sample_data, seq_len=seq_len)
|
||||
|
||||
assert dataset._data == sample_data
|
||||
assert dataset._seq_len == seq_len
|
||||
assert dataset._device == "cpu"
|
||||
|
||||
# Проверка инициализации с явным указанием устройства
|
||||
dataset_gpu = GetData(data=sample_data, seq_len=seq_len, device="cuda")
|
||||
assert dataset_gpu._device == "cuda"
|
||||
|
||||
def test_dataset_length(self, sample_data):
|
||||
"""Тест корректного вычисления длины датасета"""
|
||||
test_cases = [
|
||||
(10, 89), # seq_len=10 → len=100-10-1=89
|
||||
(50, 49), # seq_len=50 → len=100-50-1=49
|
||||
(99, 0) # seq_len=99 → len=100-99-1=0
|
||||
]
|
||||
|
||||
for seq_len, expected_len in test_cases:
|
||||
dataset = GetData(data=sample_data, seq_len=seq_len)
|
||||
assert len(dataset) == expected_len
|
||||
|
||||
def test_item_retrieval(self, sample_data):
|
||||
"""Тест получения элементов датасета"""
|
||||
seq_len = 5
|
||||
dataset = GetData(data=sample_data, seq_len=seq_len)
|
||||
|
||||
# Проверка первых элементов
|
||||
x, y = dataset[0]
|
||||
assert torch.equal(x, torch.tensor([0, 1, 2, 3, 4]))
|
||||
assert torch.equal(y, torch.tensor([1, 2, 3, 4, 5]))
|
||||
|
||||
# Проверка элементов из середины
|
||||
x, y = dataset[50]
|
||||
assert torch.equal(x, torch.tensor([50, 51, 52, 53, 54]))
|
||||
assert torch.equal(y, torch.tensor([51, 52, 53, 54, 55]))
|
||||
|
||||
# Проверка последнего элемента
|
||||
last_idx = len(dataset) - 1
|
||||
x, y = dataset[last_idx]
|
||||
expected_x = sample_data[last_idx:last_idx+seq_len]
|
||||
expected_y = sample_data[last_idx+1:last_idx+seq_len+1]
|
||||
assert torch.equal(x, torch.tensor(expected_x))
|
||||
assert torch.equal(y, torch.tensor(expected_y))
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Требуется GPU")
|
||||
def test_gpu_support(self, sample_data):
|
||||
"""Тест работы с GPU (только если доступен CUDA)"""
|
||||
seq_len = 10
|
||||
dataset = GetData(data=sample_data, seq_len=seq_len, device="cuda")
|
||||
x, y = dataset[0]
|
||||
|
||||
assert x.is_cuda
|
||||
assert y.is_cuda
|
||||
assert x.device == torch.device("cuda")
|
||||
assert y.device == torch.device("cuda")
|
||||
|
||||
def test_edge_cases(self):
|
||||
"""Тест обработки граничных случаев"""
|
||||
# Слишком длинная последовательность
|
||||
with pytest.raises(ValueError):
|
||||
GetData(data=[1, 2, 3], seq_len=4)
|
||||
|
||||
# Отрицательная длина последовательности
|
||||
with pytest.raises(ValueError):
|
||||
GetData(data=[1, 2, 3], seq_len=-1)
|
||||
|
||||
# Пустые входные данные
|
||||
with pytest.raises(ValueError):
|
||||
GetData(data=[], seq_len=1)
|
||||
|
||||
def test_tensor_conversion(self, sample_data):
|
||||
"""Тест корректности преобразования в тензоры"""
|
||||
seq_len = 3
|
||||
dataset = GetData(data=sample_data, seq_len=seq_len)
|
||||
x, y = dataset[10]
|
||||
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert isinstance(y, torch.Tensor)
|
||||
assert x.dtype == torch.int64
|
||||
assert y.dtype == torch.int64
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-v", "--tb=native"])
|
||||
Reference in New Issue
Block a user