mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
doc(datasets): update docstrings and tests
This commit is contained in:
@@ -16,7 +16,7 @@ import torch
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from llm.tokenizers import BPETokenizer
|
||||
from llm.training.dataset import TextDataset
|
||||
from llm.datasets.text_dataset import TextDataset
|
||||
from llm.training.trainer import Trainer
|
||||
|
||||
from shared.data import (
|
||||
|
||||
0
llm/src/llm/datasets/__init__.py
Normal file
0
llm/src/llm/datasets/__init__.py
Normal file
120
llm/src/llm/datasets/streaming_text_dataset.py
Normal file
120
llm/src/llm/datasets/streaming_text_dataset.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class StreamingTextDataset(Dataset):
|
||||
"""
|
||||
StreamingTextDataset — потоковый датасет для LLM/NLP на базе списка строк.
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Позволяет эффективно обрабатывать большие текстовые выборки, итерируя по заранее подготовленному списку строк.
|
||||
- При итерации строки токенизируются на лету, превращаются в примеры фиксированной длины block_size (padding/truncation внутри класса).
|
||||
- Поддерживает стандартный DataLoader PyTorch.
|
||||
|
||||
Ключевые особенности:
|
||||
---------------------
|
||||
- Не требует загрузки всей коллекции токенов в RAM: поддерживает работу с любым размером датасета, если список строк заранее подготовлен.
|
||||
- Каждый пример (sample) формируется при обращении; не хранит массив батчей, не использует файлы внутри.
|
||||
- Поддерживает любой токенизатор с методом encode (например, BPE, SentencePiece, HF Tokenizer).
|
||||
- batch_size и параллелизм (num_workers) контролируются через DataLoader.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
texts: List[str] — список строк (предварительно загруженных обучающих примеров).
|
||||
tokenizer: BaseTokenizer/Any — объект с методом encode(str, **kwargs) -> List[int].
|
||||
block_size: int — длина одного выходного примера в токенах (padding/truncation если нужно).
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> texts = open("wiki_sample.txt", encoding="utf-8").read().splitlines()
|
||||
>>> ds = StreamingTextDataset(texts, tokenizer=tokenizer, block_size=512)
|
||||
>>> loader = torch.utils.data.DataLoader(ds, batch_size=8)
|
||||
>>> for batch in loader:
|
||||
... print(batch['input_ids'].shape) # torch.Size([8, 512])
|
||||
|
||||
Особенности:
|
||||
------------
|
||||
- Проектирован для бесконечного стриминга текстовых данных из больших коллекций.
|
||||
- При batch_size > 1 каждый batch формируется DataLoader-ом из yield'ов этого датасета.
|
||||
- Не работает с файлами напрямую, только со строками (списком).
|
||||
- Подходит для обучения LLM, тестирования, дообучения, оценки на больших потоковых данных.
|
||||
|
||||
References:
|
||||
-----------
|
||||
- PyTorch IterableDataset: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
|
||||
- HuggingFace streaming datasets: https://huggingface.co/docs/datasets/stream
|
||||
- Практика масштабного обучения LLM: https://github.com/karpathy/nanoGPT/issues/182
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||
"""
|
||||
Инициализация StreamingTextDataset из списка строк.
|
||||
|
||||
Аргументы:
|
||||
texts (List[str]): Список строк — текстовые обучающие примеры; весь датасет должен помещаться в этот список.
|
||||
tokenizer (Any): Токенизатор с методом encode(text, **kwargs) -> List[int].
|
||||
block_size (int, по умолчанию 128): Желаемая длина токенизированного примера (padding/truncation внутри класса).
|
||||
|
||||
Особенности:
|
||||
- Поддерживает итеративную загрузку, эффективен для больших текстовых выборок.
|
||||
- Каждый пример автоматически дополняется или усекается до block_size.
|
||||
- Не читает данные из файла/буфера, а только из заранее подготовленного списка строк.
|
||||
|
||||
Пример:
|
||||
>>> ds = StreamingTextDataset(texts=all_lines, tokenizer=tokenizer, block_size=256)
|
||||
>>> for ex in ds:
|
||||
... print(ex['input_ids'].shape) # torch.Size([256])
|
||||
"""
|
||||
self.texts = texts
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
|
||||
# Получаем pad_token_id из токенизатора
|
||||
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Возвращает количество доступных примеров в датасете.
|
||||
|
||||
Returns:
|
||||
int: Число примеров (равно длине исходного списка строк).
|
||||
"""
|
||||
return len(self.texts)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Получить обработанный пример по индексу из потокового датасета.
|
||||
|
||||
Аргументы:
|
||||
idx (int): Индекс примера в исходном списке строк.
|
||||
|
||||
Возвращает:
|
||||
dict: Словарь с тензорами для обучения LLM:
|
||||
- 'input_ids': torch.Tensor формы [block_size] — индексы токенов (padding/truncation выполнены)
|
||||
- 'labels': torch.Tensor формы [block_size] — целевые метки (обычно совпадают с input_ids)
|
||||
|
||||
Пример:
|
||||
>>> item = dataset[10]
|
||||
>>> assert isinstance(item, dict)
|
||||
>>> assert item['input_ids'].shape == (block_size,)
|
||||
>>> assert 'labels' in item
|
||||
"""
|
||||
text = self.texts[idx]
|
||||
|
||||
# Токенизация на лету
|
||||
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Обрезаем или дополняем до нужной длины
|
||||
if len(input_ids) > self.block_size:
|
||||
input_ids = input_ids[: self.block_size]
|
||||
else:
|
||||
input_ids = input_ids + [self.pad_token_id] * (
|
||||
self.block_size - len(input_ids)
|
||||
)
|
||||
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
112
llm/src/llm/datasets/text_dataset.py
Normal file
112
llm/src/llm/datasets/text_dataset.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""
|
||||
TextDataset — простой датасет для подачи обучающих токенов LLM (batch-режим или по одному примеру).
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Хранит последовательности текста (каждую строку или пример) в виде списка строк.
|
||||
- При обращении сам токенизирует строку в последовательность индексов с помощью заданного токенизатора.
|
||||
- Каждый пример автоматически усекётся или будет дополнен до фиксированной длины block_size (padding — zeros).
|
||||
|
||||
Формат и аргументы конструктора:
|
||||
-------------------------------
|
||||
texts: List[str]
|
||||
Список строк, каждая из которых рассматривается как отдельный обучающий пример.
|
||||
tokenizer: любой объект с методом encode(str, **kwargs) → List[int]
|
||||
Обеспечивает сопоставление строки списку токенов (например, BPE, HuggingFace, SentencePiece и др.).
|
||||
block_size: int, по умолчанию 128
|
||||
Желаемая длина выходной последовательности (padding/truncation внутри класса).
|
||||
|
||||
Особенности:
|
||||
------------
|
||||
- Класс не работает с файлами напрямую: данные передаются готовым списком строк.
|
||||
- При недостаточной длине пример дополняется паддингом (нулём или другим токеном, зависит от реализации).
|
||||
- Может возвращать dict с input_ids, labels и прочими ключами (см. реализацию в функции __getitem__).
|
||||
|
||||
Пример использования:
|
||||
---------------------
|
||||
>>> with open("dataset.txt", encoding="utf-8") as f:
|
||||
... texts = f.read().splitlines()
|
||||
>>> dataset = TextDataset(texts, tokenizer, block_size=256)
|
||||
>>> from torch.utils.data import DataLoader
|
||||
>>> loader = DataLoader(dataset, batch_size=4)
|
||||
>>> for item in loader:
|
||||
... # item['input_ids'] для обучения LLM
|
||||
|
||||
References:
|
||||
-----------
|
||||
- Torch Dataset: https://pytorch.org/docs/stable/data.html
|
||||
- Примеры LLM датасетов в open-source: https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/tokenize.py
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||
"""
|
||||
Инициализация датасета из списка строк.
|
||||
|
||||
Аргументы:
|
||||
texts (List[str]): Список строк — каждый элемент отдельный обучающий пример.
|
||||
tokenizer (Any): Токенизатор с методом encode(str, **kwargs) -> List[int].
|
||||
block_size (int, по умолчанию 128): Желаемая длина результата —
|
||||
длинные последовательности будут усечены, короткие — дополнены паддингом (pad_token_id или 0).
|
||||
|
||||
Особенности:
|
||||
- Строки не фильтруются и не изменяются внутри датасета.
|
||||
- Для PAD используется pad_token_id из токенизатора (если есть) либо 0.
|
||||
- Dict, возвращаемый __getitem__, содержит 'input_ids' и 'labels'.
|
||||
|
||||
Пример:
|
||||
>>> dataset = TextDataset(["hello world", "test string"], tokenizer, block_size=16)
|
||||
"""
|
||||
self.examples = []
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
|
||||
for text in texts:
|
||||
# Кодируем текст в токены
|
||||
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Обрезаем или дополняем до нужной длины
|
||||
if len(input_ids) > block_size:
|
||||
input_ids = input_ids[:block_size]
|
||||
else:
|
||||
# Дополняем pad_token_id
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||
|
||||
self.examples.append(input_ids)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Возвращает количество примеров в датасете (длина списка текстов).
|
||||
|
||||
Returns:
|
||||
int: Число примеров в датасете.
|
||||
"""
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Получить пример из датасета по индексу.
|
||||
|
||||
Аргументы:
|
||||
idx (int): Индекс примера.
|
||||
|
||||
Возвращает:
|
||||
dict: Словарь с тензорами токенов для модели:
|
||||
- 'input_ids': torch.Tensor shape [block_size], индексы токенов для входа.
|
||||
- 'labels': torch.Tensor shape [block_size], метки для LM задачи (обычно совпадают с input_ids).
|
||||
|
||||
Пример:
|
||||
>>> item = dataset[7]
|
||||
>>> assert isinstance(item, dict)
|
||||
>>> assert item['input_ids'].shape == (block_size,)
|
||||
>>> assert 'labels' in item
|
||||
"""
|
||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
124
llm/src/llm/datasets/text_with_special_tokens_dataset.py
Normal file
124
llm/src/llm/datasets/text_with_special_tokens_dataset.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Any
|
||||
from llm.datasets.text_dataset import TextDataset
|
||||
|
||||
|
||||
class TextWithSpecialTokensDataset(TextDataset):
|
||||
"""
|
||||
TextWithSpecialTokensDataset — датасет для языковых моделей с поддержкой специальных токенов (BOS, EOS, PAD).
|
||||
|
||||
Назначение:
|
||||
-----------
|
||||
- Работает с уже готовым списком строк (не с файлом!).
|
||||
- Токенизирует строки с помощью заданного токенизатора, вручную вставляет специальные токены (BOS/ EOS/ PAD).
|
||||
- Обрезает или дополняет каждую последовательность до длины block_size.
|
||||
|
||||
Аргументы конструктора:
|
||||
-----------------------
|
||||
texts (List[str]): Список обучающих строк (примеров).
|
||||
tokenizer (Any): Любой токенизатор с методом encode(text, **kwargs).
|
||||
block_size (int, default=128): Желаемая длина примера (padding/truncation).
|
||||
add_bos (bool, default=False): Если True, добавляет BOS-токен в начало каждой последовательности.
|
||||
add_eos (bool, default=False): Если True, добавляет EOS-токен в конец.
|
||||
|
||||
Особенности:
|
||||
------------
|
||||
- Если pad_token_id не задан — по умолчанию паддит нулями.
|
||||
- Все returned примеры — dict с 'input_ids' и 'labels' (shape == block_size).
|
||||
- Обрезание/дополнение учётное: BOS/EOS не "выдавливаются" обрезкой.
|
||||
- Пример вызова:
|
||||
>>> texts = ["пример текста", "ещё текст"]
|
||||
>>> ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=16, add_bos=True, add_eos=True)
|
||||
>>> out = ds[0]
|
||||
>>> assert out['input_ids'].shape == (16,)
|
||||
|
||||
References:
|
||||
-----------
|
||||
- OpenAI GPT-2 data loader: https://github.com/openai/gpt-2/blob/master/src/encode.py
|
||||
- HuggingFace data docs: https://huggingface.co/docs/transformers/pad_truncation
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: List[str],
|
||||
tokenizer: Any,
|
||||
block_size: int = 128,
|
||||
add_bos: bool = False,
|
||||
add_eos: bool = False,
|
||||
):
|
||||
"""
|
||||
Инициализация датасета с поддержкой специальных токенов.
|
||||
|
||||
Args:
|
||||
texts (List[str]): Список строк (все ваши обучающие примеры).
|
||||
tokenizer (Any): Токенизатор с методом encode(text, **kwargs).
|
||||
block_size (int): Длина выходного примера.
|
||||
add_bos (bool): Добавлять ли BOS токен в начало.
|
||||
add_eos (bool): Добавлять ли EOS токен в конец.
|
||||
"""
|
||||
self.examples = []
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
self.add_bos = add_bos
|
||||
self.add_eos = add_eos
|
||||
|
||||
for text in texts:
|
||||
# Кодируем с специальными токенами
|
||||
input_ids = tokenizer.encode(
|
||||
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=add_eos
|
||||
)
|
||||
|
||||
# Учитываем специальные токены при обрезке/дополнении
|
||||
effective_block_size = block_size
|
||||
if add_bos:
|
||||
effective_block_size -= 1
|
||||
if add_eos:
|
||||
effective_block_size -= 1
|
||||
|
||||
if len(input_ids) > effective_block_size:
|
||||
input_ids = input_ids[:effective_block_size]
|
||||
|
||||
# Добавляем специальные токены если нужно
|
||||
if (
|
||||
add_bos
|
||||
and hasattr(tokenizer, "bos_token_id")
|
||||
and tokenizer.bos_token_id is not None
|
||||
):
|
||||
input_ids = [tokenizer.bos_token_id] + input_ids
|
||||
if (
|
||||
add_eos
|
||||
and hasattr(tokenizer, "eos_token_id")
|
||||
and tokenizer.eos_token_id is not None
|
||||
):
|
||||
input_ids = input_ids + [tokenizer.eos_token_id]
|
||||
|
||||
# Дополняем до полной длины
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
if len(input_ids) < block_size:
|
||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||
|
||||
self.examples.append(input_ids)
|
||||
|
||||
def __len__(self):
|
||||
"""
|
||||
Возвращает количество примеров в датасете.
|
||||
|
||||
Returns:
|
||||
int: Размер (len(self.examples)).
|
||||
"""
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
"""
|
||||
Получить пример с учётом специальных токенов и паддинга.
|
||||
|
||||
Args:
|
||||
idx (int): Индекс в dataset.
|
||||
|
||||
Returns:
|
||||
dict: {'input_ids': torch.Tensor [block_size], 'labels': torch.Tensor [block_size]}
|
||||
"""
|
||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
@@ -178,25 +178,50 @@ class Llama(BaseModel):
|
||||
use_cache: bool = True,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Генерация текста c помощью LLaMA (autoregressive Transformer).
|
||||
Поддерживается:
|
||||
- greedy и вероятностное сэмплирование (top-k, top-p, temperature)
|
||||
- кэш attention для ускорения генерации длинных последовательностей
|
||||
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).
|
||||
|
||||
Args:
|
||||
x (Tensor[int]): начальная последовательность [batch, seq_len]
|
||||
max_new_tokens (int): сколько новых токенов сгенерировать
|
||||
do_sample (bool): использовать стохастику (True) или жадный выбор (False)
|
||||
temperature (float): масштаб для softmax (важно для sampling)
|
||||
top_k (int|None): ограничение на количество кандидатов (top-k sampling)
|
||||
top_p (float|None): nucleus sampling
|
||||
use_cache (bool): ускоряет autoregressive при длинной генерации
|
||||
Returns:
|
||||
output (Tensor[int]): [batch, seq_len + max_new_tokens]
|
||||
Пример:
|
||||
>>> prompt = tokenizer.encode('Meta AI', return_tensors="pt")
|
||||
>>> generated = model.generate(prompt, max_new_tokens=30, do_sample=True)
|
||||
>>> print(tokenizer.decode(generated[0]))
|
||||
Аргументы:
|
||||
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||
do_sample (bool): Использовать вероятностное сэмплирование (True) или жадный режим (False, argmax).
|
||||
temperature (float): Температура (сглаживание распределения вероятностей, >0; по умолчанию 1.0).
|
||||
>1.0 — менее предсказуемые, более разнообразные выборки.
|
||||
<1.0 — более строгие, консервативные выборки.
|
||||
top_k (int, опционально): Top-k сэмплирование (ограничение выбора k самыми вероятными токенами).
|
||||
top_p (float, опционально): Nucleus (top-p) sampling (срез по кумулятивной вероятности ≤ top_p, см. Holtzman et al., 2019).
|
||||
use_cache (bool, по умолчанию True): Использовать KV-кэш для ускорения генерации.
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor: Последовательность токенов shape [batch_size, seq_len + max_new_tokens].
|
||||
|
||||
Исключения:
|
||||
ValueError: Если x длиннее максимально допустимой длины (max_seq_len модели).
|
||||
ValueError: Если temperature ≤ 0.
|
||||
ValueError: Если одновременно заданы top_k и top_p.
|
||||
ValueError: Если top_k ≤ 0.
|
||||
ValueError: Если top_p не в диапазоне (0, 1].
|
||||
|
||||
Примеры:
|
||||
>>> # Строго жадная генерация
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||
>>> # Вероятностная генерация с температурой
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.7)
|
||||
>>> # Top-k sampling
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||
>>> # Top-p (nucleus)
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||
>>> # Комбинация температуры и top-k
|
||||
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||
|
||||
Примечания:
|
||||
- temperature, top_k, top_p применяются только если do_sample=True.
|
||||
- Одновременное использование top_k и top_p запрещено.
|
||||
- Для воспроизводимых результатов зафиксируйте seed через torch.manual_seed.
|
||||
- Возвращается только индексы токенов; для получения вероятностей используйте forward.
|
||||
|
||||
Ссылки:
|
||||
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p): https://arxiv.org/abs/1904.09751
|
||||
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||
"""
|
||||
cache = None
|
||||
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from typing import List, Any
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
"""
|
||||
Простой датасет для языкового моделирования (LLM).
|
||||
Работает с любым токенизатором, реализующим интерфейс BaseTokenizer.
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||
"""
|
||||
Инициализация датасета.
|
||||
|
||||
Args:
|
||||
texts: Список текстов для обучения
|
||||
tokenizer: Токенизатор с методами encode/decode
|
||||
block_size: Максимальная длина последовательности
|
||||
"""
|
||||
self.examples = []
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
|
||||
for text in texts:
|
||||
# Кодируем текст в токены
|
||||
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Обрезаем или дополняем до нужной длины
|
||||
if len(input_ids) > block_size:
|
||||
input_ids = input_ids[:block_size]
|
||||
else:
|
||||
# Дополняем pad_token_id
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||
|
||||
self.examples.append(input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
|
||||
class StreamingTextDataset(Dataset):
|
||||
"""
|
||||
Датасет для потоковой обработки больших текстов.
|
||||
Токенизация происходит на лету, что экономит память.
|
||||
"""
|
||||
|
||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||
self.texts = texts
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
|
||||
# Получаем pad_token_id из токенизатора
|
||||
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.texts)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
text = self.texts[idx]
|
||||
|
||||
# Токенизация на лету
|
||||
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Обрезаем или дополняем до нужной длины
|
||||
if len(input_ids) > self.block_size:
|
||||
input_ids = input_ids[: self.block_size]
|
||||
else:
|
||||
input_ids = input_ids + [self.pad_token_id] * (
|
||||
self.block_size - len(input_ids)
|
||||
)
|
||||
|
||||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
|
||||
|
||||
class TextDatasetWithSpecialTokens(TextDataset):
|
||||
"""
|
||||
Расширенная версия TextDataset с поддержкой специальных токенов.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
texts: List[str],
|
||||
tokenizer: Any,
|
||||
block_size: int = 128,
|
||||
add_bos: bool = False,
|
||||
add_eos: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
texts: Список текстов
|
||||
tokenizer: Токенизатор
|
||||
block_size: Максимальная длина
|
||||
add_bos: Добавлять токен начала последовательности
|
||||
add_eos: Добавлять токен конца последовательности
|
||||
"""
|
||||
self.examples = []
|
||||
self.tokenizer = tokenizer
|
||||
self.block_size = block_size
|
||||
self.add_bos = add_bos
|
||||
self.add_eos = add_eos
|
||||
|
||||
for text in texts:
|
||||
# Кодируем с специальными токенами
|
||||
input_ids = tokenizer.encode(
|
||||
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=eos
|
||||
)
|
||||
|
||||
# Учитываем специальные токены при обрезке/дополнении
|
||||
effective_block_size = block_size
|
||||
if add_bos:
|
||||
effective_block_size -= 1
|
||||
if add_eos:
|
||||
effective_block_size -= 1
|
||||
|
||||
if len(input_ids) > effective_block_size:
|
||||
input_ids = input_ids[:effective_block_size]
|
||||
|
||||
# Добавляем специальные токены если нужно
|
||||
if (
|
||||
add_bos
|
||||
and hasattr(tokenizer, "bos_token_id")
|
||||
and tokenizer.bos_token_id is not None
|
||||
):
|
||||
input_ids = [tokenizer.bos_token_id] + input_ids
|
||||
if (
|
||||
add_eos
|
||||
and hasattr(tokenizer, "eos_token_id")
|
||||
and tokenizer.eos_token_id is not None
|
||||
):
|
||||
input_ids = input_ids + [tokenizer.eos_token_id]
|
||||
|
||||
# Дополняем до полной длины
|
||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||
if len(input_ids) < block_size:
|
||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||
|
||||
self.examples.append(input_ids)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||
labels = input_ids.clone()
|
||||
return {"input_ids": input_ids, "labels": labels}
|
||||
49
llm/tests/datasets/test_streaming_text_dataset.py
Normal file
49
llm/tests/datasets/test_streaming_text_dataset.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.datasets.streaming_text_dataset import StreamingTextDataset
|
||||
|
||||
class DummyTokenizer:
|
||||
def __init__(self, vocab_size=100):
|
||||
self.vocab_size = vocab_size
|
||||
def encode(self, text, **kwargs):
|
||||
return [len(w) % self.vocab_size for w in text.strip().split()]
|
||||
|
||||
def test_streaming_textdataset_basic_shape():
|
||||
texts = ["hello world", "big transformers are fun", "LLM test string"]
|
||||
tokenizer = DummyTokenizer(50)
|
||||
block_size = 7
|
||||
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||
assert len(ds) == 3
|
||||
for i in range(len(ds)):
|
||||
item = ds[i]
|
||||
assert isinstance(item, dict)
|
||||
assert "input_ids" in item
|
||||
assert item["input_ids"].shape == (block_size,)
|
||||
assert "labels" in item
|
||||
assert item["labels"].shape == (block_size,)
|
||||
|
||||
def test_streaming_textdataset_padding_and_truncation():
|
||||
texts = ["short", "one two three four five six seven eight nine ten"]
|
||||
tokenizer = DummyTokenizer(40)
|
||||
block_size = 4
|
||||
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||
# короткое предложение padded
|
||||
assert (ds[0]["input_ids"].shape[0] == block_size)
|
||||
# длинное предложение truncated
|
||||
assert (ds[1]["input_ids"].shape[0] == block_size)
|
||||
|
||||
def test_streaming_textdataset_index_error():
|
||||
texts = ["sample"]
|
||||
tokenizer = DummyTokenizer(10)
|
||||
ds = StreamingTextDataset(texts, tokenizer, block_size=5)
|
||||
with pytest.raises(IndexError):
|
||||
_ = ds[1]
|
||||
|
||||
def test_streaming_textdataset_content_matching():
|
||||
texts = ["foo bar baz", "abc def"]
|
||||
tokenizer = DummyTokenizer(99)
|
||||
block_size = 5
|
||||
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||
# Проверка, что input_ids и labels совпадают точно
|
||||
for i in range(len(ds)):
|
||||
assert torch.equal(ds[i]["input_ids"], ds[i]["labels"])
|
||||
49
llm/tests/datasets/test_text_dataset.py
Normal file
49
llm/tests/datasets/test_text_dataset.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.datasets.text_dataset import TextDataset
|
||||
|
||||
class DummyTokenizer:
|
||||
def __init__(self, vocab_size=100):
|
||||
self.vocab_size = vocab_size
|
||||
def encode(self, text, **kwargs):
|
||||
return [len(w) % self.vocab_size for w in text.strip().split()]
|
||||
|
||||
def test_textdataset_shape_and_basic():
|
||||
texts = ["hello world", "this is a test", "Transformer model"]
|
||||
tokenizer = DummyTokenizer(50)
|
||||
block_size = 6
|
||||
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||
for i in range(len(dataset)):
|
||||
x = dataset[i]
|
||||
assert isinstance(x, dict)
|
||||
assert "input_ids" in x
|
||||
assert isinstance(x["input_ids"], torch.Tensor)
|
||||
assert x["input_ids"].shape == (block_size,)
|
||||
|
||||
def test_textdataset_truncation_and_padding():
|
||||
texts = ["one two three four five six seven", "short"]
|
||||
tokenizer = DummyTokenizer(100)
|
||||
block_size = 5
|
||||
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||
assert isinstance(dataset[0], dict)
|
||||
assert dataset[0]["input_ids"].shape[0] == 5
|
||||
assert dataset[1]["input_ids"].shape[0] == 5
|
||||
|
||||
def test_textdataset_index_error():
|
||||
texts = ["a", "b"]
|
||||
tokenizer = DummyTokenizer(10)
|
||||
dataset = TextDataset(texts, tokenizer, block_size=3)
|
||||
with pytest.raises(IndexError):
|
||||
_ = dataset[2]
|
||||
|
||||
def test_textdataset_encoding():
|
||||
texts = ["привет", "мир"]
|
||||
tokenizer = DummyTokenizer(20)
|
||||
block_size = 4
|
||||
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||
assert len(dataset) == 2
|
||||
x = dataset[0]
|
||||
assert isinstance(x, dict)
|
||||
assert "input_ids" in x
|
||||
assert isinstance(x["input_ids"], torch.Tensor)
|
||||
assert x["input_ids"].shape == (block_size,)
|
||||
62
llm/tests/datasets/test_text_with_special_tokens_dataset.py
Normal file
62
llm/tests/datasets/test_text_with_special_tokens_dataset.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import torch
|
||||
import pytest
|
||||
from llm.datasets.text_with_special_tokens_dataset import TextWithSpecialTokensDataset
|
||||
|
||||
class DummyTokenizer:
|
||||
def __init__(self):
|
||||
self.bos_token_id = 101
|
||||
self.eos_token_id = 102
|
||||
self.pad_token_id = 0
|
||||
def encode(self, text, add_special_tokens=False, add_bos_token=False, add_eos_token=False):
|
||||
ids = [ord(c) % 50 for c in text.strip()]
|
||||
if add_bos_token:
|
||||
ids = [self.bos_token_id] + ids
|
||||
if add_eos_token:
|
||||
ids = ids + [self.eos_token_id]
|
||||
return ids
|
||||
|
||||
def test_specialtokens_basic_bos_eos():
|
||||
texts = ["abc", "d"]
|
||||
tokenizer = DummyTokenizer()
|
||||
block_size = 6
|
||||
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
|
||||
for i in range(len(ds)):
|
||||
item = ds[i]
|
||||
assert isinstance(item, dict)
|
||||
assert "input_ids" in item
|
||||
assert item["input_ids"].shape == (block_size,)
|
||||
assert item["input_ids"][0] == tokenizer.bos_token_id
|
||||
assert item["input_ids"][item["input_ids"].ne(tokenizer.pad_token_id).sum() - 1] == tokenizer.eos_token_id
|
||||
|
||||
def test_specialtokens_padding_and_truncation():
|
||||
texts = ["qwertyuiop", "z"]
|
||||
tokenizer = DummyTokenizer()
|
||||
block_size = 5
|
||||
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True)
|
||||
assert ds[0]["input_ids"].shape[0] == block_size
|
||||
assert ds[1]["input_ids"][-1] == tokenizer.pad_token_id
|
||||
|
||||
def test_specialtokens_no_bos_eos():
|
||||
texts = ["xyz"]
|
||||
tokenizer = DummyTokenizer()
|
||||
block_size = 6
|
||||
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=False, add_eos=False)
|
||||
item = ds[0]["input_ids"]
|
||||
assert tokenizer.bos_token_id not in item
|
||||
assert tokenizer.eos_token_id not in item
|
||||
assert item.shape == (block_size,)
|
||||
|
||||
def test_specialtokens_index_error():
|
||||
texts = ["sample"]
|
||||
tokenizer = DummyTokenizer()
|
||||
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=8)
|
||||
with pytest.raises(IndexError):
|
||||
_ = ds[1]
|
||||
|
||||
def test_specialtokens_labels():
|
||||
texts = ["abcd"]
|
||||
tokenizer = DummyTokenizer()
|
||||
block_size = 7
|
||||
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
|
||||
item = ds[0]
|
||||
assert torch.equal(item["input_ids"], item["labels"])
|
||||
Reference in New Issue
Block a user