mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
doc(datasets): update docstrings and tests
This commit is contained in:
@@ -16,7 +16,7 @@ import torch
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
from llm.tokenizers import BPETokenizer
|
from llm.tokenizers import BPETokenizer
|
||||||
from llm.training.dataset import TextDataset
|
from llm.datasets.text_dataset import TextDataset
|
||||||
from llm.training.trainer import Trainer
|
from llm.training.trainer import Trainer
|
||||||
|
|
||||||
from shared.data import (
|
from shared.data import (
|
||||||
|
|||||||
0
llm/src/llm/datasets/__init__.py
Normal file
0
llm/src/llm/datasets/__init__.py
Normal file
120
llm/src/llm/datasets/streaming_text_dataset.py
Normal file
120
llm/src/llm/datasets/streaming_text_dataset.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingTextDataset(Dataset):
|
||||||
|
"""
|
||||||
|
StreamingTextDataset — потоковый датасет для LLM/NLP на базе списка строк.
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Позволяет эффективно обрабатывать большие текстовые выборки, итерируя по заранее подготовленному списку строк.
|
||||||
|
- При итерации строки токенизируются на лету, превращаются в примеры фиксированной длины block_size (padding/truncation внутри класса).
|
||||||
|
- Поддерживает стандартный DataLoader PyTorch.
|
||||||
|
|
||||||
|
Ключевые особенности:
|
||||||
|
---------------------
|
||||||
|
- Не требует загрузки всей коллекции токенов в RAM: поддерживает работу с любым размером датасета, если список строк заранее подготовлен.
|
||||||
|
- Каждый пример (sample) формируется при обращении; не хранит массив батчей, не использует файлы внутри.
|
||||||
|
- Поддерживает любой токенизатор с методом encode (например, BPE, SentencePiece, HF Tokenizer).
|
||||||
|
- batch_size и параллелизм (num_workers) контролируются через DataLoader.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
texts: List[str] — список строк (предварительно загруженных обучающих примеров).
|
||||||
|
tokenizer: BaseTokenizer/Any — объект с методом encode(str, **kwargs) -> List[int].
|
||||||
|
block_size: int — длина одного выходного примера в токенах (padding/truncation если нужно).
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> texts = open("wiki_sample.txt", encoding="utf-8").read().splitlines()
|
||||||
|
>>> ds = StreamingTextDataset(texts, tokenizer=tokenizer, block_size=512)
|
||||||
|
>>> loader = torch.utils.data.DataLoader(ds, batch_size=8)
|
||||||
|
>>> for batch in loader:
|
||||||
|
... print(batch['input_ids'].shape) # torch.Size([8, 512])
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
------------
|
||||||
|
- Проектирован для бесконечного стриминга текстовых данных из больших коллекций.
|
||||||
|
- При batch_size > 1 каждый batch формируется DataLoader-ом из yield'ов этого датасета.
|
||||||
|
- Не работает с файлами напрямую, только со строками (списком).
|
||||||
|
- Подходит для обучения LLM, тестирования, дообучения, оценки на больших потоковых данных.
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- PyTorch IterableDataset: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
|
||||||
|
- HuggingFace streaming datasets: https://huggingface.co/docs/datasets/stream
|
||||||
|
- Практика масштабного обучения LLM: https://github.com/karpathy/nanoGPT/issues/182
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||||
|
"""
|
||||||
|
Инициализация StreamingTextDataset из списка строк.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
texts (List[str]): Список строк — текстовые обучающие примеры; весь датасет должен помещаться в этот список.
|
||||||
|
tokenizer (Any): Токенизатор с методом encode(text, **kwargs) -> List[int].
|
||||||
|
block_size (int, по умолчанию 128): Желаемая длина токенизированного примера (padding/truncation внутри класса).
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
- Поддерживает итеративную загрузку, эффективен для больших текстовых выборок.
|
||||||
|
- Каждый пример автоматически дополняется или усекается до block_size.
|
||||||
|
- Не читает данные из файла/буфера, а только из заранее подготовленного списка строк.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> ds = StreamingTextDataset(texts=all_lines, tokenizer=tokenizer, block_size=256)
|
||||||
|
>>> for ex in ds:
|
||||||
|
... print(ex['input_ids'].shape) # torch.Size([256])
|
||||||
|
"""
|
||||||
|
self.texts = texts
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.block_size = block_size
|
||||||
|
|
||||||
|
# Получаем pad_token_id из токенизатора
|
||||||
|
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Возвращает количество доступных примеров в датасете.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Число примеров (равно длине исходного списка строк).
|
||||||
|
"""
|
||||||
|
return len(self.texts)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
"""
|
||||||
|
Получить обработанный пример по индексу из потокового датасета.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
idx (int): Индекс примера в исходном списке строк.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
dict: Словарь с тензорами для обучения LLM:
|
||||||
|
- 'input_ids': torch.Tensor формы [block_size] — индексы токенов (padding/truncation выполнены)
|
||||||
|
- 'labels': torch.Tensor формы [block_size] — целевые метки (обычно совпадают с input_ids)
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> item = dataset[10]
|
||||||
|
>>> assert isinstance(item, dict)
|
||||||
|
>>> assert item['input_ids'].shape == (block_size,)
|
||||||
|
>>> assert 'labels' in item
|
||||||
|
"""
|
||||||
|
text = self.texts[idx]
|
||||||
|
|
||||||
|
# Токенизация на лету
|
||||||
|
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
# Обрезаем или дополняем до нужной длины
|
||||||
|
if len(input_ids) > self.block_size:
|
||||||
|
input_ids = input_ids[: self.block_size]
|
||||||
|
else:
|
||||||
|
input_ids = input_ids + [self.pad_token_id] * (
|
||||||
|
self.block_size - len(input_ids)
|
||||||
|
)
|
||||||
|
|
||||||
|
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
||||||
|
labels = input_ids.clone()
|
||||||
|
|
||||||
|
return {"input_ids": input_ids, "labels": labels}
|
||||||
112
llm/src/llm/datasets/text_dataset.py
Normal file
112
llm/src/llm/datasets/text_dataset.py
Normal file
@@ -0,0 +1,112 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from typing import List, Any
|
||||||
|
|
||||||
|
|
||||||
|
class TextDataset(Dataset):
|
||||||
|
"""
|
||||||
|
TextDataset — простой датасет для подачи обучающих токенов LLM (batch-режим или по одному примеру).
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Хранит последовательности текста (каждую строку или пример) в виде списка строк.
|
||||||
|
- При обращении сам токенизирует строку в последовательность индексов с помощью заданного токенизатора.
|
||||||
|
- Каждый пример автоматически усекётся или будет дополнен до фиксированной длины block_size (padding — zeros).
|
||||||
|
|
||||||
|
Формат и аргументы конструктора:
|
||||||
|
-------------------------------
|
||||||
|
texts: List[str]
|
||||||
|
Список строк, каждая из которых рассматривается как отдельный обучающий пример.
|
||||||
|
tokenizer: любой объект с методом encode(str, **kwargs) → List[int]
|
||||||
|
Обеспечивает сопоставление строки списку токенов (например, BPE, HuggingFace, SentencePiece и др.).
|
||||||
|
block_size: int, по умолчанию 128
|
||||||
|
Желаемая длина выходной последовательности (padding/truncation внутри класса).
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
------------
|
||||||
|
- Класс не работает с файлами напрямую: данные передаются готовым списком строк.
|
||||||
|
- При недостаточной длине пример дополняется паддингом (нулём или другим токеном, зависит от реализации).
|
||||||
|
- Может возвращать dict с input_ids, labels и прочими ключами (см. реализацию в функции __getitem__).
|
||||||
|
|
||||||
|
Пример использования:
|
||||||
|
---------------------
|
||||||
|
>>> with open("dataset.txt", encoding="utf-8") as f:
|
||||||
|
... texts = f.read().splitlines()
|
||||||
|
>>> dataset = TextDataset(texts, tokenizer, block_size=256)
|
||||||
|
>>> from torch.utils.data import DataLoader
|
||||||
|
>>> loader = DataLoader(dataset, batch_size=4)
|
||||||
|
>>> for item in loader:
|
||||||
|
... # item['input_ids'] для обучения LLM
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- Torch Dataset: https://pytorch.org/docs/stable/data.html
|
||||||
|
- Примеры LLM датасетов в open-source: https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/tokenize.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
||||||
|
"""
|
||||||
|
Инициализация датасета из списка строк.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
texts (List[str]): Список строк — каждый элемент отдельный обучающий пример.
|
||||||
|
tokenizer (Any): Токенизатор с методом encode(str, **kwargs) -> List[int].
|
||||||
|
block_size (int, по умолчанию 128): Желаемая длина результата —
|
||||||
|
длинные последовательности будут усечены, короткие — дополнены паддингом (pad_token_id или 0).
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
- Строки не фильтруются и не изменяются внутри датасета.
|
||||||
|
- Для PAD используется pad_token_id из токенизатора (если есть) либо 0.
|
||||||
|
- Dict, возвращаемый __getitem__, содержит 'input_ids' и 'labels'.
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> dataset = TextDataset(["hello world", "test string"], tokenizer, block_size=16)
|
||||||
|
"""
|
||||||
|
self.examples = []
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.block_size = block_size
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
# Кодируем текст в токены
|
||||||
|
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||||
|
|
||||||
|
# Обрезаем или дополняем до нужной длины
|
||||||
|
if len(input_ids) > block_size:
|
||||||
|
input_ids = input_ids[:block_size]
|
||||||
|
else:
|
||||||
|
# Дополняем pad_token_id
|
||||||
|
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||||
|
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||||
|
|
||||||
|
self.examples.append(input_ids)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Возвращает количество примеров в датасете (длина списка текстов).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Число примеров в датасете.
|
||||||
|
"""
|
||||||
|
return len(self.examples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
"""
|
||||||
|
Получить пример из датасета по индексу.
|
||||||
|
|
||||||
|
Аргументы:
|
||||||
|
idx (int): Индекс примера.
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
dict: Словарь с тензорами токенов для модели:
|
||||||
|
- 'input_ids': torch.Tensor shape [block_size], индексы токенов для входа.
|
||||||
|
- 'labels': torch.Tensor shape [block_size], метки для LM задачи (обычно совпадают с input_ids).
|
||||||
|
|
||||||
|
Пример:
|
||||||
|
>>> item = dataset[7]
|
||||||
|
>>> assert isinstance(item, dict)
|
||||||
|
>>> assert item['input_ids'].shape == (block_size,)
|
||||||
|
>>> assert 'labels' in item
|
||||||
|
"""
|
||||||
|
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||||
|
labels = input_ids.clone()
|
||||||
|
return {"input_ids": input_ids, "labels": labels}
|
||||||
124
llm/src/llm/datasets/text_with_special_tokens_dataset.py
Normal file
124
llm/src/llm/datasets/text_with_special_tokens_dataset.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import torch
|
||||||
|
from torch.utils.data import Dataset
|
||||||
|
from typing import List, Any
|
||||||
|
from llm.datasets.text_dataset import TextDataset
|
||||||
|
|
||||||
|
|
||||||
|
class TextWithSpecialTokensDataset(TextDataset):
|
||||||
|
"""
|
||||||
|
TextWithSpecialTokensDataset — датасет для языковых моделей с поддержкой специальных токенов (BOS, EOS, PAD).
|
||||||
|
|
||||||
|
Назначение:
|
||||||
|
-----------
|
||||||
|
- Работает с уже готовым списком строк (не с файлом!).
|
||||||
|
- Токенизирует строки с помощью заданного токенизатора, вручную вставляет специальные токены (BOS/ EOS/ PAD).
|
||||||
|
- Обрезает или дополняет каждую последовательность до длины block_size.
|
||||||
|
|
||||||
|
Аргументы конструктора:
|
||||||
|
-----------------------
|
||||||
|
texts (List[str]): Список обучающих строк (примеров).
|
||||||
|
tokenizer (Any): Любой токенизатор с методом encode(text, **kwargs).
|
||||||
|
block_size (int, default=128): Желаемая длина примера (padding/truncation).
|
||||||
|
add_bos (bool, default=False): Если True, добавляет BOS-токен в начало каждой последовательности.
|
||||||
|
add_eos (bool, default=False): Если True, добавляет EOS-токен в конец.
|
||||||
|
|
||||||
|
Особенности:
|
||||||
|
------------
|
||||||
|
- Если pad_token_id не задан — по умолчанию паддит нулями.
|
||||||
|
- Все returned примеры — dict с 'input_ids' и 'labels' (shape == block_size).
|
||||||
|
- Обрезание/дополнение учётное: BOS/EOS не "выдавливаются" обрезкой.
|
||||||
|
- Пример вызова:
|
||||||
|
>>> texts = ["пример текста", "ещё текст"]
|
||||||
|
>>> ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=16, add_bos=True, add_eos=True)
|
||||||
|
>>> out = ds[0]
|
||||||
|
>>> assert out['input_ids'].shape == (16,)
|
||||||
|
|
||||||
|
References:
|
||||||
|
-----------
|
||||||
|
- OpenAI GPT-2 data loader: https://github.com/openai/gpt-2/blob/master/src/encode.py
|
||||||
|
- HuggingFace data docs: https://huggingface.co/docs/transformers/pad_truncation
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
texts: List[str],
|
||||||
|
tokenizer: Any,
|
||||||
|
block_size: int = 128,
|
||||||
|
add_bos: bool = False,
|
||||||
|
add_eos: bool = False,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Инициализация датасета с поддержкой специальных токенов.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts (List[str]): Список строк (все ваши обучающие примеры).
|
||||||
|
tokenizer (Any): Токенизатор с методом encode(text, **kwargs).
|
||||||
|
block_size (int): Длина выходного примера.
|
||||||
|
add_bos (bool): Добавлять ли BOS токен в начало.
|
||||||
|
add_eos (bool): Добавлять ли EOS токен в конец.
|
||||||
|
"""
|
||||||
|
self.examples = []
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.block_size = block_size
|
||||||
|
self.add_bos = add_bos
|
||||||
|
self.add_eos = add_eos
|
||||||
|
|
||||||
|
for text in texts:
|
||||||
|
# Кодируем с специальными токенами
|
||||||
|
input_ids = tokenizer.encode(
|
||||||
|
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=add_eos
|
||||||
|
)
|
||||||
|
|
||||||
|
# Учитываем специальные токены при обрезке/дополнении
|
||||||
|
effective_block_size = block_size
|
||||||
|
if add_bos:
|
||||||
|
effective_block_size -= 1
|
||||||
|
if add_eos:
|
||||||
|
effective_block_size -= 1
|
||||||
|
|
||||||
|
if len(input_ids) > effective_block_size:
|
||||||
|
input_ids = input_ids[:effective_block_size]
|
||||||
|
|
||||||
|
# Добавляем специальные токены если нужно
|
||||||
|
if (
|
||||||
|
add_bos
|
||||||
|
and hasattr(tokenizer, "bos_token_id")
|
||||||
|
and tokenizer.bos_token_id is not None
|
||||||
|
):
|
||||||
|
input_ids = [tokenizer.bos_token_id] + input_ids
|
||||||
|
if (
|
||||||
|
add_eos
|
||||||
|
and hasattr(tokenizer, "eos_token_id")
|
||||||
|
and tokenizer.eos_token_id is not None
|
||||||
|
):
|
||||||
|
input_ids = input_ids + [tokenizer.eos_token_id]
|
||||||
|
|
||||||
|
# Дополняем до полной длины
|
||||||
|
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
||||||
|
if len(input_ids) < block_size:
|
||||||
|
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
||||||
|
|
||||||
|
self.examples.append(input_ids)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Возвращает количество примеров в датасете.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Размер (len(self.examples)).
|
||||||
|
"""
|
||||||
|
return len(self.examples)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
"""
|
||||||
|
Получить пример с учётом специальных токенов и паддинга.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
idx (int): Индекс в dataset.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: {'input_ids': torch.Tensor [block_size], 'labels': torch.Tensor [block_size]}
|
||||||
|
"""
|
||||||
|
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
||||||
|
labels = input_ids.clone()
|
||||||
|
return {"input_ids": input_ids, "labels": labels}
|
||||||
@@ -107,12 +107,12 @@ class Llama(BaseModel):
|
|||||||
) -> tuple:
|
) -> tuple:
|
||||||
"""
|
"""
|
||||||
Прямой проход: возвращает logits (и возможно обновлённый cache) по входным токенам.
|
Прямой проход: возвращает logits (и возможно обновлённый cache) по входным токенам.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x (torch.Tensor): [batch, seq_len] — индексы токенов, shape [batch, seq_len]
|
x (torch.Tensor): [batch, seq_len] — индексы токенов, shape [batch, seq_len]
|
||||||
use_cache (bool): использовать механизм KV cache (ускоряет autoregressive generation)
|
use_cache (bool): использовать механизм KV cache (ускоряет autoregressive generation)
|
||||||
cache (list or None): предыдущий кэш, если нужен
|
cache (list or None): предыдущий кэш, если нужен
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
logits: torch.Tensor [batch, seq_len, vocab_size]
|
logits: torch.Tensor [batch, seq_len, vocab_size]
|
||||||
new_cache: новый кэш attention (или None)
|
new_cache: новый кэш attention (или None)
|
||||||
@@ -178,25 +178,50 @@ class Llama(BaseModel):
|
|||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Генерация текста c помощью LLaMA (autoregressive Transformer).
|
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).
|
||||||
Поддерживается:
|
|
||||||
- greedy и вероятностное сэмплирование (top-k, top-p, temperature)
|
Аргументы:
|
||||||
- кэш attention для ускорения генерации длинных последовательностей
|
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
|
||||||
|
max_new_tokens (int): Максимальное количество новых токенов для генерации.
|
||||||
Args:
|
do_sample (bool): Использовать вероятностное сэмплирование (True) или жадный режим (False, argmax).
|
||||||
x (Tensor[int]): начальная последовательность [batch, seq_len]
|
temperature (float): Температура (сглаживание распределения вероятностей, >0; по умолчанию 1.0).
|
||||||
max_new_tokens (int): сколько новых токенов сгенерировать
|
>1.0 — менее предсказуемые, более разнообразные выборки.
|
||||||
do_sample (bool): использовать стохастику (True) или жадный выбор (False)
|
<1.0 — более строгие, консервативные выборки.
|
||||||
temperature (float): масштаб для softmax (важно для sampling)
|
top_k (int, опционально): Top-k сэмплирование (ограничение выбора k самыми вероятными токенами).
|
||||||
top_k (int|None): ограничение на количество кандидатов (top-k sampling)
|
top_p (float, опционально): Nucleus (top-p) sampling (срез по кумулятивной вероятности ≤ top_p, см. Holtzman et al., 2019).
|
||||||
top_p (float|None): nucleus sampling
|
use_cache (bool, по умолчанию True): Использовать KV-кэш для ускорения генерации.
|
||||||
use_cache (bool): ускоряет autoregressive при длинной генерации
|
|
||||||
Returns:
|
Возвращает:
|
||||||
output (Tensor[int]): [batch, seq_len + max_new_tokens]
|
torch.Tensor: Последовательность токенов shape [batch_size, seq_len + max_new_tokens].
|
||||||
Пример:
|
|
||||||
>>> prompt = tokenizer.encode('Meta AI', return_tensors="pt")
|
Исключения:
|
||||||
>>> generated = model.generate(prompt, max_new_tokens=30, do_sample=True)
|
ValueError: Если x длиннее максимально допустимой длины (max_seq_len модели).
|
||||||
>>> print(tokenizer.decode(generated[0]))
|
ValueError: Если temperature ≤ 0.
|
||||||
|
ValueError: Если одновременно заданы top_k и top_p.
|
||||||
|
ValueError: Если top_k ≤ 0.
|
||||||
|
ValueError: Если top_p не в диапазоне (0, 1].
|
||||||
|
|
||||||
|
Примеры:
|
||||||
|
>>> # Строго жадная генерация
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
|
||||||
|
>>> # Вероятностная генерация с температурой
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.7)
|
||||||
|
>>> # Top-k sampling
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
|
||||||
|
>>> # Top-p (nucleus)
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
|
||||||
|
>>> # Комбинация температуры и top-k
|
||||||
|
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
|
||||||
|
|
||||||
|
Примечания:
|
||||||
|
- temperature, top_k, top_p применяются только если do_sample=True.
|
||||||
|
- Одновременное использование top_k и top_p запрещено.
|
||||||
|
- Для воспроизводимых результатов зафиксируйте seed через torch.manual_seed.
|
||||||
|
- Возвращается только индексы токенов; для получения вероятностей используйте forward.
|
||||||
|
|
||||||
|
Ссылки:
|
||||||
|
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p): https://arxiv.org/abs/1904.09751
|
||||||
|
- LLaMA: https://arxiv.org/abs/2302.13971
|
||||||
"""
|
"""
|
||||||
cache = None
|
cache = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,155 +0,0 @@
|
|||||||
import torch
|
|
||||||
from torch.utils.data import Dataset
|
|
||||||
from typing import List, Any
|
|
||||||
|
|
||||||
|
|
||||||
class TextDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Простой датасет для языкового моделирования (LLM).
|
|
||||||
Работает с любым токенизатором, реализующим интерфейс BaseTokenizer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
|
||||||
"""
|
|
||||||
Инициализация датасета.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
texts: Список текстов для обучения
|
|
||||||
tokenizer: Токенизатор с методами encode/decode
|
|
||||||
block_size: Максимальная длина последовательности
|
|
||||||
"""
|
|
||||||
self.examples = []
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.block_size = block_size
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
# Кодируем текст в токены
|
|
||||||
input_ids = tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
# Обрезаем или дополняем до нужной длины
|
|
||||||
if len(input_ids) > block_size:
|
|
||||||
input_ids = input_ids[:block_size]
|
|
||||||
else:
|
|
||||||
# Дополняем pad_token_id
|
|
||||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
|
||||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
|
||||||
|
|
||||||
self.examples.append(input_ids)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.examples)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
|
||||||
labels = input_ids.clone()
|
|
||||||
return {"input_ids": input_ids, "labels": labels}
|
|
||||||
|
|
||||||
|
|
||||||
class StreamingTextDataset(Dataset):
|
|
||||||
"""
|
|
||||||
Датасет для потоковой обработки больших текстов.
|
|
||||||
Токенизация происходит на лету, что экономит память.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
|
|
||||||
self.texts = texts
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.block_size = block_size
|
|
||||||
|
|
||||||
# Получаем pad_token_id из токенизатора
|
|
||||||
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.texts)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
text = self.texts[idx]
|
|
||||||
|
|
||||||
# Токенизация на лету
|
|
||||||
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
|
||||||
|
|
||||||
# Обрезаем или дополняем до нужной длины
|
|
||||||
if len(input_ids) > self.block_size:
|
|
||||||
input_ids = input_ids[: self.block_size]
|
|
||||||
else:
|
|
||||||
input_ids = input_ids + [self.pad_token_id] * (
|
|
||||||
self.block_size - len(input_ids)
|
|
||||||
)
|
|
||||||
|
|
||||||
input_ids = torch.tensor(input_ids, dtype=torch.long)
|
|
||||||
labels = input_ids.clone()
|
|
||||||
|
|
||||||
return {"input_ids": input_ids, "labels": labels}
|
|
||||||
|
|
||||||
|
|
||||||
class TextDatasetWithSpecialTokens(TextDataset):
|
|
||||||
"""
|
|
||||||
Расширенная версия TextDataset с поддержкой специальных токенов.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
texts: List[str],
|
|
||||||
tokenizer: Any,
|
|
||||||
block_size: int = 128,
|
|
||||||
add_bos: bool = False,
|
|
||||||
add_eos: bool = False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
texts: Список текстов
|
|
||||||
tokenizer: Токенизатор
|
|
||||||
block_size: Максимальная длина
|
|
||||||
add_bos: Добавлять токен начала последовательности
|
|
||||||
add_eos: Добавлять токен конца последовательности
|
|
||||||
"""
|
|
||||||
self.examples = []
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.block_size = block_size
|
|
||||||
self.add_bos = add_bos
|
|
||||||
self.add_eos = add_eos
|
|
||||||
|
|
||||||
for text in texts:
|
|
||||||
# Кодируем с специальными токенами
|
|
||||||
input_ids = tokenizer.encode(
|
|
||||||
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=eos
|
|
||||||
)
|
|
||||||
|
|
||||||
# Учитываем специальные токены при обрезке/дополнении
|
|
||||||
effective_block_size = block_size
|
|
||||||
if add_bos:
|
|
||||||
effective_block_size -= 1
|
|
||||||
if add_eos:
|
|
||||||
effective_block_size -= 1
|
|
||||||
|
|
||||||
if len(input_ids) > effective_block_size:
|
|
||||||
input_ids = input_ids[:effective_block_size]
|
|
||||||
|
|
||||||
# Добавляем специальные токены если нужно
|
|
||||||
if (
|
|
||||||
add_bos
|
|
||||||
and hasattr(tokenizer, "bos_token_id")
|
|
||||||
and tokenizer.bos_token_id is not None
|
|
||||||
):
|
|
||||||
input_ids = [tokenizer.bos_token_id] + input_ids
|
|
||||||
if (
|
|
||||||
add_eos
|
|
||||||
and hasattr(tokenizer, "eos_token_id")
|
|
||||||
and tokenizer.eos_token_id is not None
|
|
||||||
):
|
|
||||||
input_ids = input_ids + [tokenizer.eos_token_id]
|
|
||||||
|
|
||||||
# Дополняем до полной длины
|
|
||||||
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
|
|
||||||
if len(input_ids) < block_size:
|
|
||||||
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
|
|
||||||
|
|
||||||
self.examples.append(input_ids)
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.examples)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
|
|
||||||
labels = input_ids.clone()
|
|
||||||
return {"input_ids": input_ids, "labels": labels}
|
|
||||||
49
llm/tests/datasets/test_streaming_text_dataset.py
Normal file
49
llm/tests/datasets/test_streaming_text_dataset.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.datasets.streaming_text_dataset import StreamingTextDataset
|
||||||
|
|
||||||
|
class DummyTokenizer:
|
||||||
|
def __init__(self, vocab_size=100):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
def encode(self, text, **kwargs):
|
||||||
|
return [len(w) % self.vocab_size for w in text.strip().split()]
|
||||||
|
|
||||||
|
def test_streaming_textdataset_basic_shape():
|
||||||
|
texts = ["hello world", "big transformers are fun", "LLM test string"]
|
||||||
|
tokenizer = DummyTokenizer(50)
|
||||||
|
block_size = 7
|
||||||
|
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||||
|
assert len(ds) == 3
|
||||||
|
for i in range(len(ds)):
|
||||||
|
item = ds[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "input_ids" in item
|
||||||
|
assert item["input_ids"].shape == (block_size,)
|
||||||
|
assert "labels" in item
|
||||||
|
assert item["labels"].shape == (block_size,)
|
||||||
|
|
||||||
|
def test_streaming_textdataset_padding_and_truncation():
|
||||||
|
texts = ["short", "one two three four five six seven eight nine ten"]
|
||||||
|
tokenizer = DummyTokenizer(40)
|
||||||
|
block_size = 4
|
||||||
|
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||||
|
# короткое предложение padded
|
||||||
|
assert (ds[0]["input_ids"].shape[0] == block_size)
|
||||||
|
# длинное предложение truncated
|
||||||
|
assert (ds[1]["input_ids"].shape[0] == block_size)
|
||||||
|
|
||||||
|
def test_streaming_textdataset_index_error():
|
||||||
|
texts = ["sample"]
|
||||||
|
tokenizer = DummyTokenizer(10)
|
||||||
|
ds = StreamingTextDataset(texts, tokenizer, block_size=5)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
_ = ds[1]
|
||||||
|
|
||||||
|
def test_streaming_textdataset_content_matching():
|
||||||
|
texts = ["foo bar baz", "abc def"]
|
||||||
|
tokenizer = DummyTokenizer(99)
|
||||||
|
block_size = 5
|
||||||
|
ds = StreamingTextDataset(texts, tokenizer, block_size)
|
||||||
|
# Проверка, что input_ids и labels совпадают точно
|
||||||
|
for i in range(len(ds)):
|
||||||
|
assert torch.equal(ds[i]["input_ids"], ds[i]["labels"])
|
||||||
49
llm/tests/datasets/test_text_dataset.py
Normal file
49
llm/tests/datasets/test_text_dataset.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.datasets.text_dataset import TextDataset
|
||||||
|
|
||||||
|
class DummyTokenizer:
|
||||||
|
def __init__(self, vocab_size=100):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
def encode(self, text, **kwargs):
|
||||||
|
return [len(w) % self.vocab_size for w in text.strip().split()]
|
||||||
|
|
||||||
|
def test_textdataset_shape_and_basic():
|
||||||
|
texts = ["hello world", "this is a test", "Transformer model"]
|
||||||
|
tokenizer = DummyTokenizer(50)
|
||||||
|
block_size = 6
|
||||||
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||||
|
for i in range(len(dataset)):
|
||||||
|
x = dataset[i]
|
||||||
|
assert isinstance(x, dict)
|
||||||
|
assert "input_ids" in x
|
||||||
|
assert isinstance(x["input_ids"], torch.Tensor)
|
||||||
|
assert x["input_ids"].shape == (block_size,)
|
||||||
|
|
||||||
|
def test_textdataset_truncation_and_padding():
|
||||||
|
texts = ["one two three four five six seven", "short"]
|
||||||
|
tokenizer = DummyTokenizer(100)
|
||||||
|
block_size = 5
|
||||||
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||||
|
assert isinstance(dataset[0], dict)
|
||||||
|
assert dataset[0]["input_ids"].shape[0] == 5
|
||||||
|
assert dataset[1]["input_ids"].shape[0] == 5
|
||||||
|
|
||||||
|
def test_textdataset_index_error():
|
||||||
|
texts = ["a", "b"]
|
||||||
|
tokenizer = DummyTokenizer(10)
|
||||||
|
dataset = TextDataset(texts, tokenizer, block_size=3)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
_ = dataset[2]
|
||||||
|
|
||||||
|
def test_textdataset_encoding():
|
||||||
|
texts = ["привет", "мир"]
|
||||||
|
tokenizer = DummyTokenizer(20)
|
||||||
|
block_size = 4
|
||||||
|
dataset = TextDataset(texts, tokenizer, block_size=block_size)
|
||||||
|
assert len(dataset) == 2
|
||||||
|
x = dataset[0]
|
||||||
|
assert isinstance(x, dict)
|
||||||
|
assert "input_ids" in x
|
||||||
|
assert isinstance(x["input_ids"], torch.Tensor)
|
||||||
|
assert x["input_ids"].shape == (block_size,)
|
||||||
62
llm/tests/datasets/test_text_with_special_tokens_dataset.py
Normal file
62
llm/tests/datasets/test_text_with_special_tokens_dataset.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from llm.datasets.text_with_special_tokens_dataset import TextWithSpecialTokensDataset
|
||||||
|
|
||||||
|
class DummyTokenizer:
|
||||||
|
def __init__(self):
|
||||||
|
self.bos_token_id = 101
|
||||||
|
self.eos_token_id = 102
|
||||||
|
self.pad_token_id = 0
|
||||||
|
def encode(self, text, add_special_tokens=False, add_bos_token=False, add_eos_token=False):
|
||||||
|
ids = [ord(c) % 50 for c in text.strip()]
|
||||||
|
if add_bos_token:
|
||||||
|
ids = [self.bos_token_id] + ids
|
||||||
|
if add_eos_token:
|
||||||
|
ids = ids + [self.eos_token_id]
|
||||||
|
return ids
|
||||||
|
|
||||||
|
def test_specialtokens_basic_bos_eos():
|
||||||
|
texts = ["abc", "d"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
block_size = 6
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
|
||||||
|
for i in range(len(ds)):
|
||||||
|
item = ds[i]
|
||||||
|
assert isinstance(item, dict)
|
||||||
|
assert "input_ids" in item
|
||||||
|
assert item["input_ids"].shape == (block_size,)
|
||||||
|
assert item["input_ids"][0] == tokenizer.bos_token_id
|
||||||
|
assert item["input_ids"][item["input_ids"].ne(tokenizer.pad_token_id).sum() - 1] == tokenizer.eos_token_id
|
||||||
|
|
||||||
|
def test_specialtokens_padding_and_truncation():
|
||||||
|
texts = ["qwertyuiop", "z"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
block_size = 5
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True)
|
||||||
|
assert ds[0]["input_ids"].shape[0] == block_size
|
||||||
|
assert ds[1]["input_ids"][-1] == tokenizer.pad_token_id
|
||||||
|
|
||||||
|
def test_specialtokens_no_bos_eos():
|
||||||
|
texts = ["xyz"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
block_size = 6
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=False, add_eos=False)
|
||||||
|
item = ds[0]["input_ids"]
|
||||||
|
assert tokenizer.bos_token_id not in item
|
||||||
|
assert tokenizer.eos_token_id not in item
|
||||||
|
assert item.shape == (block_size,)
|
||||||
|
|
||||||
|
def test_specialtokens_index_error():
|
||||||
|
texts = ["sample"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=8)
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
_ = ds[1]
|
||||||
|
|
||||||
|
def test_specialtokens_labels():
|
||||||
|
texts = ["abcd"]
|
||||||
|
tokenizer = DummyTokenizer()
|
||||||
|
block_size = 7
|
||||||
|
ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=block_size, add_bos=True, add_eos=True)
|
||||||
|
item = ds[0]
|
||||||
|
assert torch.equal(item["input_ids"], item["labels"])
|
||||||
Reference in New Issue
Block a user