Files
simple-llm/tests/test_get_data.py
Sergey Penkovsky c56a3e80c9 Реализация класса GetData
- Добавлен класс GetData для работы с последовательными данными
- Реализован функционал:
  * Создание датасета из последовательности
  * Автоматическое формирование пар (input, target)
  * Поддержка CPU/GPU
  * Проверка корректности параметров
- Добавлены тесты для проверки функционала
- Создан пример использования в example/
- Добавлена документация с блок-схемой в doc/
- Обновлен README.md с информацией о новом классе
2025-07-22 17:10:28 +03:00

102 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pytest
import torch
import sys
import os
from simple_llm.data.get_data import GetData
class TestGetData:
"""Набор тестов для проверки класса GetData"""
@pytest.fixture
def sample_data(self):
"""Фикстура с тестовыми данными: последовательность чисел 0-99"""
return list(range(100))
def test_initialization(self, sample_data):
"""Тест корректности инициализации класса"""
seq_len = 10
dataset = GetData(data=sample_data, seq_len=seq_len)
assert dataset._data == sample_data
assert dataset._seq_len == seq_len
assert dataset._device == "cpu"
# Проверка инициализации с явным указанием устройства
dataset_gpu = GetData(data=sample_data, seq_len=seq_len, device="cuda")
assert dataset_gpu._device == "cuda"
def test_dataset_length(self, sample_data):
"""Тест корректного вычисления длины датасета"""
test_cases = [
(10, 89), # seq_len=10 → len=100-10-1=89
(50, 49), # seq_len=50 → len=100-50-1=49
(99, 0) # seq_len=99 → len=100-99-1=0
]
for seq_len, expected_len in test_cases:
dataset = GetData(data=sample_data, seq_len=seq_len)
assert len(dataset) == expected_len
def test_item_retrieval(self, sample_data):
"""Тест получения элементов датасета"""
seq_len = 5
dataset = GetData(data=sample_data, seq_len=seq_len)
# Проверка первых элементов
x, y = dataset[0]
assert torch.equal(x, torch.tensor([0, 1, 2, 3, 4]))
assert torch.equal(y, torch.tensor([1, 2, 3, 4, 5]))
# Проверка элементов из середины
x, y = dataset[50]
assert torch.equal(x, torch.tensor([50, 51, 52, 53, 54]))
assert torch.equal(y, torch.tensor([51, 52, 53, 54, 55]))
# Проверка последнего элемента
last_idx = len(dataset) - 1
x, y = dataset[last_idx]
expected_x = sample_data[last_idx:last_idx+seq_len]
expected_y = sample_data[last_idx+1:last_idx+seq_len+1]
assert torch.equal(x, torch.tensor(expected_x))
assert torch.equal(y, torch.tensor(expected_y))
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Требуется GPU")
def test_gpu_support(self, sample_data):
"""Тест работы с GPU (только если доступен CUDA)"""
seq_len = 10
dataset = GetData(data=sample_data, seq_len=seq_len, device="cuda")
x, y = dataset[0]
assert x.is_cuda
assert y.is_cuda
assert x.device == torch.device("cuda")
assert y.device == torch.device("cuda")
def test_edge_cases(self):
"""Тест обработки граничных случаев"""
# Слишком длинная последовательность
with pytest.raises(ValueError):
GetData(data=[1, 2, 3], seq_len=4)
# Отрицательная длина последовательности
with pytest.raises(ValueError):
GetData(data=[1, 2, 3], seq_len=-1)
# Пустые входные данные
with pytest.raises(ValueError):
GetData(data=[], seq_len=1)
def test_tensor_conversion(self, sample_data):
"""Тест корректности преобразования в тензоры"""
seq_len = 3
dataset = GetData(data=sample_data, seq_len=seq_len)
x, y = dataset[10]
assert isinstance(x, torch.Tensor)
assert isinstance(y, torch.Tensor)
assert x.dtype == torch.int64
assert y.dtype == torch.int64
if __name__ == "__main__":
pytest.main(["-v", "--tb=native"])