mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Реализация класса GetData
- Добавлен класс GetData для работы с последовательными данными - Реализован функционал: * Создание датасета из последовательности * Автоматическое формирование пар (input, target) * Поддержка CPU/GPU * Проверка корректности параметров - Добавлены тесты для проверки функционала - Создан пример использования в example/ - Добавлена документация с блок-схемой в doc/ - Обновлен README.md с информацией о новом классе
This commit is contained in:
12
README.md
12
README.md
@@ -28,6 +28,17 @@ python example/example_gpt.py
|
|||||||
|
|
||||||
## 🧠 Основные компоненты
|
## 🧠 Основные компоненты
|
||||||
|
|
||||||
|
### Обработка данных
|
||||||
|
```python
|
||||||
|
from simple_llm.data.get_data import GetData
|
||||||
|
|
||||||
|
dataset = GetData(
|
||||||
|
data=[1, 2, 3, 4, 5], # Входная последовательность
|
||||||
|
seq_len=3, # Длина окна
|
||||||
|
device="cuda" # Устройство (опционально)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
### Модель GPT
|
### Модель GPT
|
||||||
```python
|
```python
|
||||||
from simple_llm.transformer.gpt import GPT
|
from simple_llm.transformer.gpt import GPT
|
||||||
@@ -57,6 +68,7 @@ output = model.generate(
|
|||||||
Полная документация доступна в [doc/](./doc/):
|
Полная документация доступна в [doc/](./doc/):
|
||||||
- [Архитектура GPT](./doc/gpt_documentation_ru.md)
|
- [Архитектура GPT](./doc/gpt_documentation_ru.md)
|
||||||
- [Алгоритм BPE](./doc/bpe_algorithm.md)
|
- [Алгоритм BPE](./doc/bpe_algorithm.md)
|
||||||
|
- [Обработка последовательностей](./doc/get_data_documentation_ru.md)
|
||||||
- [Примеры использования](./example/)
|
- [Примеры использования](./example/)
|
||||||
|
|
||||||
## 🛠 Тестирование
|
## 🛠 Тестирование
|
||||||
|
|||||||
79
doc/get_data_documentation_ru.md
Normal file
79
doc/get_data_documentation_ru.md
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
# Документация по классу GetData
|
||||||
|
|
||||||
|
## Назначение
|
||||||
|
Класс `GetData` предназначен для создания датасетов из последовательных данных для обучения языковых моделей и других задач, работающих с последовательностями.
|
||||||
|
|
||||||
|
## Основные возможности
|
||||||
|
- Преобразование последовательности данных в обучающие пары (input, target)
|
||||||
|
- Поддержка различных типов данных (числа, токены)
|
||||||
|
- Автоматический сдвиг целевой последовательности
|
||||||
|
- Поддержка работы на CPU/GPU
|
||||||
|
- Проверка корректности входных параметров
|
||||||
|
|
||||||
|
## Алгоритм работы
|
||||||
|
1. Принимает на вход последовательность данных и длину окна
|
||||||
|
2. Скользящим окном проходит по последовательности
|
||||||
|
3. Для каждой позиции создает пару:
|
||||||
|
- Входная последовательность: `data[pos:pos+seq_len]`
|
||||||
|
- Целевая последовательность: `data[pos+1:pos+seq_len+1]` (сдвиг на 1 элемент)
|
||||||
|
4. Преобразует данные в тензоры PyTorch
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
flowchart TD
|
||||||
|
A[Начало] --> B[Проверка параметров]
|
||||||
|
B -->|seq_len <= 0| C[Ошибка: отрицательная длина]
|
||||||
|
B -->|seq_len >= len(data)| D[Ошибка: слишком длинное окно]
|
||||||
|
B -->|Параметры верны| E[Инициализация датасета]
|
||||||
|
E --> F[Для каждого индекса i от 0 до len(data)-seq_len-1]
|
||||||
|
F --> G[Входной тензор: data[i:i+seq_len]]
|
||||||
|
G --> H[Целевой тензор: data[i+1:i+seq_len+1]]
|
||||||
|
H --> I[Преобразование в тензоры PyTorch]
|
||||||
|
I --> J[Возврат пары тензоров]
|
||||||
|
J --> F
|
||||||
|
F -->|Все индексы обработаны| K[Конец]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Пример использования
|
||||||
|
```python
|
||||||
|
from simple_llm.data.get_data import GetData
|
||||||
|
|
||||||
|
data = list(range(10)) # Последовательность 0-9
|
||||||
|
seq_len = 3
|
||||||
|
dataset = GetData(data=data, seq_len=seq_len)
|
||||||
|
|
||||||
|
# Получение первого примера
|
||||||
|
x, y = dataset[0]
|
||||||
|
print(f"Вход: {x.tolist()} → Цель: {y.tolist()}")
|
||||||
|
# Вывод: Вход: [0, 1, 2] → Цель: [1, 2, 3]
|
||||||
|
```
|
||||||
|
|
||||||
|
## Параметры класса
|
||||||
|
- `data` (list): Входная последовательность данных
|
||||||
|
- `seq_len` (int): Длина окна последовательности
|
||||||
|
- `device` (str): Устройство для тензоров ('cpu' или 'cuda')
|
||||||
|
|
||||||
|
## Методы
|
||||||
|
- `__len__()`: Возвращает количество обучающих примеров
|
||||||
|
- `__getitem__(idx)`: Возвращает пару тензоров по индексу
|
||||||
|
|
||||||
|
## Ошибки
|
||||||
|
- `ValueError`: Если `seq_len` <= 0 или >= длины данных
|
||||||
|
|
||||||
|
## Применение
|
||||||
|
1. Обучение языковых моделей
|
||||||
|
2. Прогнозирование временных рядов
|
||||||
|
3. Любые задачи, требующие работы с последовательностями
|
||||||
|
|
||||||
|
## Рекомендации
|
||||||
|
- Для текстовых данных предварительно токенизируйте текст
|
||||||
|
- Для больших датасетов используйте GPU (device='cuda')
|
||||||
|
- Подбирайте seq_len в зависимости от задачи
|
||||||
|
|
||||||
|
## Пример с текстовыми данными
|
||||||
|
```python
|
||||||
|
text_tokens = [10, 20, 30, 40] # Токенизированный текст
|
||||||
|
dataset = GetData(text_tokens, seq_len=2)
|
||||||
|
x, y = dataset[1]
|
||||||
|
print(f"Вход: {x.tolist()} → Цель: {y.tolist()}")
|
||||||
|
# Вывод: Вход: [20, 30] → Цель: [30, 40]
|
||||||
|
```
|
||||||
57
example/example_get_data.py
Normal file
57
example/example_get_data.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""
|
||||||
|
Пример использования класса GetData для работы с последовательными данными.
|
||||||
|
|
||||||
|
Этот пример показывает:
|
||||||
|
1. Как создать датасет из последовательности чисел
|
||||||
|
2. Как получить пары (вход, цель) для обучения
|
||||||
|
3. Как работать с разными длинами последовательностей
|
||||||
|
4. Как использовать GPU (если доступен)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from simple_llm.data.get_data import GetData
|
||||||
|
import torch
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 1. Простейший пример с последовательностью чисел
|
||||||
|
print("\n=== Пример 1: Базовая последовательность ===")
|
||||||
|
data = list(range(10)) # [0, 1, 2, ..., 9]
|
||||||
|
seq_len = 3
|
||||||
|
dataset = GetData(data=data, seq_len=seq_len)
|
||||||
|
|
||||||
|
print(f"Длина датасета: {len(dataset)}")
|
||||||
|
for i in range(min(3, len(dataset))): # Покажем первые 3 примера
|
||||||
|
x, y = dataset[i]
|
||||||
|
print(f"Пример {i}:")
|
||||||
|
print(f" Вход: {x.tolist()} → Цель: {y.tolist()}")
|
||||||
|
|
||||||
|
# 2. Пример с текстовыми данными (последовательность токенов)
|
||||||
|
print("\n=== Пример 2: Токенизированный текст ===")
|
||||||
|
text_tokens = [10, 20, 30, 40, 50, 60, 70] # Пример токенов
|
||||||
|
text_seq_len = 2
|
||||||
|
text_dataset = GetData(data=text_tokens, seq_len=text_seq_len)
|
||||||
|
|
||||||
|
print(f"Длина датасета: {len(text_dataset)}")
|
||||||
|
for i in range(len(text_dataset)):
|
||||||
|
x, y = text_dataset[i]
|
||||||
|
print(f"Пример {i}: {x.tolist()} → {y.tolist()}")
|
||||||
|
|
||||||
|
# 3. Пример с использованием GPU (если доступен)
|
||||||
|
print("\n=== Пример 3: Работа с GPU ===")
|
||||||
|
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
print(f"Используемое устройство: {device}")
|
||||||
|
|
||||||
|
gpu_dataset = GetData(data=data, seq_len=seq_len, device=device)
|
||||||
|
x, y = gpu_dataset[0]
|
||||||
|
print(f"Пример на {device}:")
|
||||||
|
print(f" Вход: {x.tolist()} (устройство: {x.device})")
|
||||||
|
print(f" Цель: {y.tolist()} (устройство: {y.device})")
|
||||||
|
|
||||||
|
# 4. Пример обработки ошибок
|
||||||
|
print("\n=== Пример 4: Обработка ошибок ===")
|
||||||
|
try:
|
||||||
|
GetData(data=[1, 2, 3], seq_len=4) # Слишком длинная последовательность
|
||||||
|
except ValueError as e:
|
||||||
|
print(f"Ошибка: {e}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
simple_llm/data/__init__.py
Normal file
0
simple_llm/data/__init__.py
Normal file
77
simple_llm/data/get_data.py
Normal file
77
simple_llm/data/get_data.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
|
class GetData(Dataset):
|
||||||
|
"""
|
||||||
|
Класс для создания датасета последовательных данных для обучения языковых моделей.
|
||||||
|
|
||||||
|
Наследуется от torch.utils.data.Dataset и реализует:
|
||||||
|
- Скользящее окно по последовательности данных
|
||||||
|
- Автоматическое разделение на входные и целевые последовательности
|
||||||
|
- Поддержку работы на CPU/GPU
|
||||||
|
- Проверку корректности параметров
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data (List): Обучающая последовательность (список чисел или токенов)
|
||||||
|
seq_len (int): Длина одной обучающей последовательности (в элементах).
|
||||||
|
Должна быть положительной и меньше длины данных.
|
||||||
|
device (str, optional): Устройство для тензоров ('cpu' или 'cuda'). По умолчанию 'cpu'.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Если seq_len <= 0 или seq_len >= len(data)
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
_data (List): Хранит входную последовательность
|
||||||
|
_seq_len (int): Длина последовательности для обучения
|
||||||
|
_device (str): Устройство для вычислений
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
||||||
|
>>> dataset = GetData(data, seq_len=3)
|
||||||
|
>>> len(dataset)
|
||||||
|
6
|
||||||
|
>>> dataset[0]
|
||||||
|
(tensor([1, 2, 3]), tensor([2, 3, 4]))
|
||||||
|
|
||||||
|
# Некорректные параметры
|
||||||
|
>>> GetData(data=[1, 2, 3], seq_len=4) # Вызовет ValueError
|
||||||
|
>>> GetData(data=[1, 2, 3], seq_len=-1) # Вызовет ValueError
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, data: list, seq_len: int, device: str = "cpu") -> None:
|
||||||
|
"""Инициализация датасета с последовательными данными."""
|
||||||
|
if seq_len <= 0:
|
||||||
|
raise ValueError(f"Sequence length must be positive, got {seq_len}")
|
||||||
|
if seq_len >= len(data):
|
||||||
|
raise ValueError(f"Sequence length {seq_len} must be less than data length {len(data)}")
|
||||||
|
self._data = data
|
||||||
|
self._seq_len = seq_len
|
||||||
|
self._device = device
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""
|
||||||
|
Возвращает количество обучающих примеров в датасете.
|
||||||
|
|
||||||
|
Формула:
|
||||||
|
N - seq_len - 1
|
||||||
|
где N - длина всей последовательности
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Количество доступных последовательностей
|
||||||
|
"""
|
||||||
|
return len(self._data) - self._seq_len - 1
|
||||||
|
|
||||||
|
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Возвращает один обучающий пример по индексу.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx (int): Позиция начала последовательности
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[torch.Tensor, torch.Tensor]: Пара (входная_последовательность, целевая_последовательность)
|
||||||
|
где целевая последовательность сдвинута на 1 элемент вперед
|
||||||
|
"""
|
||||||
|
x = torch.tensor(self._data[idx:idx+self._seq_len]).to(self._device)
|
||||||
|
y = torch.tensor(self._data[idx+1:idx+self._seq_len+1]).to(self._device)
|
||||||
|
return (x, y)
|
||||||
0
simple_llm/embedding/__init__.py
Normal file
0
simple_llm/embedding/__init__.py
Normal file
0
simple_llm/transformer/__init__.py
Normal file
0
simple_llm/transformer/__init__.py
Normal file
102
tests/test_get_data.py
Normal file
102
tests/test_get_data.py
Normal file
@@ -0,0 +1,102 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
from simple_llm.data.get_data import GetData
|
||||||
|
|
||||||
|
class TestGetData:
|
||||||
|
"""Набор тестов для проверки класса GetData"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_data(self):
|
||||||
|
"""Фикстура с тестовыми данными: последовательность чисел 0-99"""
|
||||||
|
return list(range(100))
|
||||||
|
|
||||||
|
def test_initialization(self, sample_data):
|
||||||
|
"""Тест корректности инициализации класса"""
|
||||||
|
seq_len = 10
|
||||||
|
dataset = GetData(data=sample_data, seq_len=seq_len)
|
||||||
|
|
||||||
|
assert dataset._data == sample_data
|
||||||
|
assert dataset._seq_len == seq_len
|
||||||
|
assert dataset._device == "cpu"
|
||||||
|
|
||||||
|
# Проверка инициализации с явным указанием устройства
|
||||||
|
dataset_gpu = GetData(data=sample_data, seq_len=seq_len, device="cuda")
|
||||||
|
assert dataset_gpu._device == "cuda"
|
||||||
|
|
||||||
|
def test_dataset_length(self, sample_data):
|
||||||
|
"""Тест корректного вычисления длины датасета"""
|
||||||
|
test_cases = [
|
||||||
|
(10, 89), # seq_len=10 → len=100-10-1=89
|
||||||
|
(50, 49), # seq_len=50 → len=100-50-1=49
|
||||||
|
(99, 0) # seq_len=99 → len=100-99-1=0
|
||||||
|
]
|
||||||
|
|
||||||
|
for seq_len, expected_len in test_cases:
|
||||||
|
dataset = GetData(data=sample_data, seq_len=seq_len)
|
||||||
|
assert len(dataset) == expected_len
|
||||||
|
|
||||||
|
def test_item_retrieval(self, sample_data):
|
||||||
|
"""Тест получения элементов датасета"""
|
||||||
|
seq_len = 5
|
||||||
|
dataset = GetData(data=sample_data, seq_len=seq_len)
|
||||||
|
|
||||||
|
# Проверка первых элементов
|
||||||
|
x, y = dataset[0]
|
||||||
|
assert torch.equal(x, torch.tensor([0, 1, 2, 3, 4]))
|
||||||
|
assert torch.equal(y, torch.tensor([1, 2, 3, 4, 5]))
|
||||||
|
|
||||||
|
# Проверка элементов из середины
|
||||||
|
x, y = dataset[50]
|
||||||
|
assert torch.equal(x, torch.tensor([50, 51, 52, 53, 54]))
|
||||||
|
assert torch.equal(y, torch.tensor([51, 52, 53, 54, 55]))
|
||||||
|
|
||||||
|
# Проверка последнего элемента
|
||||||
|
last_idx = len(dataset) - 1
|
||||||
|
x, y = dataset[last_idx]
|
||||||
|
expected_x = sample_data[last_idx:last_idx+seq_len]
|
||||||
|
expected_y = sample_data[last_idx+1:last_idx+seq_len+1]
|
||||||
|
assert torch.equal(x, torch.tensor(expected_x))
|
||||||
|
assert torch.equal(y, torch.tensor(expected_y))
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Требуется GPU")
|
||||||
|
def test_gpu_support(self, sample_data):
|
||||||
|
"""Тест работы с GPU (только если доступен CUDA)"""
|
||||||
|
seq_len = 10
|
||||||
|
dataset = GetData(data=sample_data, seq_len=seq_len, device="cuda")
|
||||||
|
x, y = dataset[0]
|
||||||
|
|
||||||
|
assert x.is_cuda
|
||||||
|
assert y.is_cuda
|
||||||
|
assert x.device == torch.device("cuda")
|
||||||
|
assert y.device == torch.device("cuda")
|
||||||
|
|
||||||
|
def test_edge_cases(self):
|
||||||
|
"""Тест обработки граничных случаев"""
|
||||||
|
# Слишком длинная последовательность
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
GetData(data=[1, 2, 3], seq_len=4)
|
||||||
|
|
||||||
|
# Отрицательная длина последовательности
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
GetData(data=[1, 2, 3], seq_len=-1)
|
||||||
|
|
||||||
|
# Пустые входные данные
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
GetData(data=[], seq_len=1)
|
||||||
|
|
||||||
|
def test_tensor_conversion(self, sample_data):
|
||||||
|
"""Тест корректности преобразования в тензоры"""
|
||||||
|
seq_len = 3
|
||||||
|
dataset = GetData(data=sample_data, seq_len=seq_len)
|
||||||
|
x, y = dataset[10]
|
||||||
|
|
||||||
|
assert isinstance(x, torch.Tensor)
|
||||||
|
assert isinstance(y, torch.Tensor)
|
||||||
|
assert x.dtype == torch.int64
|
||||||
|
assert y.dtype == torch.int64
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main(["-v", "--tb=native"])
|
||||||
Reference in New Issue
Block a user