Files
simple-llm/simple_llm/embedding/token_embeddings.py

58 lines
2.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from torch import nn
from torch import Tensor
class TokenEmbeddings(nn.Module):
"""
Модуль PyTorch для преобразования индексов токенов в векторные представления (эмбеддинги).
Преобразует целочисленные индексы токенов в обучаемые векторные представления фиксированного размера.
Обычно используется как первый слой в нейронных сетях для задач NLP.
Аргументы:
vocab_size (int): Размер словаря (количество уникальных токенов)
emb_size (int): Размерность векторных представлений
Форматы данных:
- Вход: тензор (batch_size, seq_len) индексов токенов
- Выход: тензор (batch_size, seq_len, emb_size) векторных представлений
Примеры использования:
>>> embedding_layer = TokenEmbeddings(vocab_size=10000, emb_size=256)
>>> tokens = torch.tensor([[1, 2, 3], [4, 5, 6]]) # batch_size=2, seq_len=3
>>> embeddings = embedding_layer(tokens)
>>> embeddings.shape
torch.Size([2, 3, 256])
Примечание:
- Индексы должны быть в диапазоне [0, vocab_size-1]
- Эмбеддинги инициализируются случайно и обучаются в процессе тренировки модели
"""
def __init__(self, vocab_size: int, emb_size: int):
super().__init__()
self._embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=emb_size
)
def forward(self, x: Tensor) -> Tensor:
return self._embedding(x)
if __name__ == "__main__":
# Пример использования
embedding = TokenEmbeddings(vocab_size=100, emb_size=128)
# Создаем тензор с индексами в пределах vocab_size (0-99)
tensor = torch.tensor([
[11, 45, 76, 34],
[34, 67, 45, 54]
])
# Проверяем индексы
if (tensor >= 100).any():
raise ValueError("Some indices are out of vocabulary range (vocab_size=100)")
output = embedding(tensor)
print("Embeddings shape:", output.shape)
print(f"{output.shape} | {output.mean().item():.11f}") # Формат как в ТЗ