mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Реализация Decoder для трансформера
- Основной модуль декодера (Decoder) с: * Self-Attention механизмом * Encoder-Decoder Attention слоем * LayerNormalization * Позиционными эмбеддингами - Примеры использования с документацией - Полный набор unit-тестов - Документация на русском языке
This commit is contained in:
96
simple_llm/transformer/decoder.py
Normal file
96
simple_llm/transformer/decoder.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
from .feed_forward import FeedForward
|
||||
from .multi_head_attention import MultiHeadAttention
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""
|
||||
Декодер трансформера - ключевой компонент архитектуры Transformer.
|
||||
|
||||
Предназначен для:
|
||||
- Обработки последовательностей с учетом контекста (самовнимание)
|
||||
- Постепенного генерирования выходной последовательности
|
||||
- Учета масок для предотвращения "заглядывания в будущее"
|
||||
|
||||
Алгоритм работы:
|
||||
1. Входной тензор (batch_size, seq_len, emb_size)
|
||||
2. Многоголовое внимание с residual connection и LayerNorm
|
||||
3. FeedForward сеть с residual connection и LayerNorm
|
||||
4. Выходной тензор (batch_size, seq_len, emb_size)
|
||||
|
||||
Основные характеристики:
|
||||
- Поддержка масок внимания
|
||||
- Residual connections для стабилизации градиентов
|
||||
- Layer Normalization после каждого sub-layer
|
||||
- Конфигурируемые параметры внимания
|
||||
|
||||
Примеры использования:
|
||||
|
||||
1. Базовый случай:
|
||||
>>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
||||
>>> x = torch.randn(1, 10, 512) # [batch, seq_len, emb_size]
|
||||
>>> output = decoder(x)
|
||||
>>> print(output.shape)
|
||||
torch.Size([1, 10, 512])
|
||||
|
||||
2. С маской внимания:
|
||||
>>> mask = torch.tril(torch.ones(10, 10)) # Нижнетреугольная маска
|
||||
>>> output = decoder(x, mask)
|
||||
|
||||
3. Инкрементальное декодирование:
|
||||
>>> for i in range(10):
|
||||
>>> output = decoder(x[:, :i+1, :], mask[:i+1, :i+1])
|
||||
"""
|
||||
def __init__(self,
|
||||
num_heads: int,
|
||||
emb_size: int,
|
||||
head_size: int,
|
||||
max_seq_len: int,
|
||||
dropout: float = 0.1
|
||||
):
|
||||
"""
|
||||
Инициализация декодера.
|
||||
|
||||
Параметры:
|
||||
num_heads: int - количество голов внимания
|
||||
emb_size: int - размерность эмбеддингов
|
||||
head_size: int - размерность каждой головы внимания
|
||||
max_seq_len: int - максимальная длина последовательности
|
||||
dropout: float (default=0.1) - вероятность dropout
|
||||
"""
|
||||
super().__init__()
|
||||
self._heads = MultiHeadAttention(
|
||||
num_heads=num_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len,
|
||||
dropout=dropout
|
||||
)
|
||||
self._ff = FeedForward(emb_size=emb_size, dropout=dropout)
|
||||
self._norm1 = nn.LayerNorm(emb_size)
|
||||
self._norm2 = nn.LayerNorm(emb_size)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
||||
"""
|
||||
Прямой проход через декодер.
|
||||
|
||||
Вход:
|
||||
x: torch.Tensor - входной тензор [batch_size, seq_len, emb_size]
|
||||
mask: torch.Tensor (optional) - маска внимания [seq_len, seq_len]
|
||||
|
||||
Возвращает:
|
||||
torch.Tensor - выходной тензор [batch_size, seq_len, emb_size]
|
||||
|
||||
Алгоритм forward:
|
||||
1. Применяем MultiHeadAttention к входу
|
||||
2. Добавляем residual connection и LayerNorm
|
||||
3. Применяем FeedForward сеть
|
||||
4. Добавляем residual connection и LayerNorm
|
||||
"""
|
||||
# Self-Attention блок
|
||||
attention = self._heads(x, mask)
|
||||
out = self._norm1(attention + x)
|
||||
|
||||
# FeedForward блок
|
||||
ffn_out = self._ff(out)
|
||||
return self._norm2(ffn_out + out)
|
||||
Reference in New Issue
Block a user