mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Реализация Decoder для трансформера
- Основной модуль декодера (Decoder) с: * Self-Attention механизмом * Encoder-Decoder Attention слоем * LayerNormalization * Позиционными эмбеддингами - Примеры использования с документацией - Полный набор unit-тестов - Документация на русском языке
This commit is contained in:
@@ -17,6 +17,7 @@
|
|||||||
- `HeadAttention` - механизм внимания одной головы
|
- `HeadAttention` - механизм внимания одной головы
|
||||||
- `MultiHeadAttention` - многоголовое внимание (4-16 голов)
|
- `MultiHeadAttention` - многоголовое внимание (4-16 голов)
|
||||||
- `FeedForward` - двухслойная FFN сеть (расширение → сжатие)
|
- `FeedForward` - двухслойная FFN сеть (расширение → сжатие)
|
||||||
|
- `Decoder` - полный декодер Transformer (Self-Attention + FFN)
|
||||||
|
|
||||||
## Быстрый старт
|
## Быстрый старт
|
||||||
|
|
||||||
@@ -40,12 +41,14 @@ model = nn.Sequential(
|
|||||||
- [Токенизация](/doc/bpe_algorithm.md)
|
- [Токенизация](/doc/bpe_algorithm.md)
|
||||||
- [MultiHeadAttention](/doc/multi_head_attention_ru.md)
|
- [MultiHeadAttention](/doc/multi_head_attention_ru.md)
|
||||||
- [FeedForward](/doc/feed_forward_ru.md)
|
- [FeedForward](/doc/feed_forward_ru.md)
|
||||||
|
- [Decoder](/doc/decoder_ru.md)
|
||||||
|
|
||||||
## Примеры
|
## Примеры
|
||||||
```bash
|
```bash
|
||||||
# Запуск примеров
|
# Запуск примеров
|
||||||
python -m example.multi_head_attention_example # Визуализация внимания
|
python -m example.multi_head_attention_example # Визуализация внимания
|
||||||
python -m example.feed_forward_example # Анализ FFN слоя
|
python -m example.feed_forward_example # Анализ FFN слоя
|
||||||
|
python -m example.decoder_example # Демонстрация декодера
|
||||||
```
|
```
|
||||||
|
|
||||||
## Установка
|
## Установка
|
||||||
|
|||||||
90
doc/decoder_ru.md
Normal file
90
doc/decoder_ru.md
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# Декодер Transformer
|
||||||
|
|
||||||
|
## Назначение
|
||||||
|
Декодер - ключевой компонент архитектуры Transformer, предназначенный для:
|
||||||
|
- Генерации последовательностей (текст, код и др.)
|
||||||
|
- Обработки входных данных с учетом контекста
|
||||||
|
- Постепенного построения выходной последовательности
|
||||||
|
- Работы с масками внимания (предотвращение "утечки" будущего)
|
||||||
|
|
||||||
|
## Алгоритм работы
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TD
|
||||||
|
A[Входной тензор] --> B[Многоголовое внимание]
|
||||||
|
B --> C[Residual + LayerNorm]
|
||||||
|
C --> D[FeedForward Network]
|
||||||
|
D --> E[Residual + LayerNorm]
|
||||||
|
E --> F[Выходной тензор]
|
||||||
|
|
||||||
|
style A fill:#f9f,stroke:#333
|
||||||
|
style F fill:#bbf,stroke:#333
|
||||||
|
```
|
||||||
|
|
||||||
|
1. **Self-Attention**:
|
||||||
|
- Вычисление внимания между всеми позициями
|
||||||
|
- Учет масок для автопрегрессивного декодирования
|
||||||
|
- Multi-head механизм (параллельные вычисления)
|
||||||
|
|
||||||
|
2. **Residual Connection + LayerNorm**:
|
||||||
|
- Стабилизация градиентов
|
||||||
|
- Ускорение обучения
|
||||||
|
- Нормализация активаций
|
||||||
|
|
||||||
|
3. **FeedForward Network**:
|
||||||
|
- Нелинейное преобразование
|
||||||
|
- Расширение скрытого пространства
|
||||||
|
- Дополнительная емкость модели
|
||||||
|
|
||||||
|
## Использование
|
||||||
|
|
||||||
|
```python
|
||||||
|
from simple_llm.transformer.decoder import Decoder
|
||||||
|
|
||||||
|
# Инициализация
|
||||||
|
decoder = Decoder(
|
||||||
|
num_heads=8,
|
||||||
|
emb_size=512,
|
||||||
|
head_size=64,
|
||||||
|
max_seq_len=1024
|
||||||
|
)
|
||||||
|
|
||||||
|
# Прямой проход
|
||||||
|
x = torch.randn(1, 10, 512) # [batch, seq_len, emb_size]
|
||||||
|
output = decoder(x)
|
||||||
|
|
||||||
|
# С маской
|
||||||
|
mask = torch.tril(torch.ones(10, 10))
|
||||||
|
masked_output = decoder(x, mask)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Параметры
|
||||||
|
|
||||||
|
| Параметр | Тип | Описание |
|
||||||
|
|--------------|------|----------|
|
||||||
|
| num_heads | int | Количество голов внимания |
|
||||||
|
| emb_size | int | Размерность эмбеддингов |
|
||||||
|
| head_size | int | Размерность каждой головы |
|
||||||
|
| max_seq_len | int | Макс. длина последовательности |
|
||||||
|
| dropout | float| Вероятность дропаута (0.1 по умолч.) |
|
||||||
|
|
||||||
|
## Применение в архитектурах
|
||||||
|
|
||||||
|
- GPT (автопрегрессивные модели)
|
||||||
|
- Нейронный машинный перевод
|
||||||
|
- Генерация текста
|
||||||
|
- Кодогенерация
|
||||||
|
|
||||||
|
## Особенности реализации
|
||||||
|
|
||||||
|
1. **Масштабирование**:
|
||||||
|
- Поддержка длинных последовательностей
|
||||||
|
- Оптимизированные вычисления внимания
|
||||||
|
|
||||||
|
2. **Обучение**:
|
||||||
|
- Поддержка teacher forcing
|
||||||
|
- Автопрегрессивное декодирование
|
||||||
|
|
||||||
|
3. **Оптимизации**:
|
||||||
|
- Кэширование ключей/значений
|
||||||
|
- Пакетная обработка
|
||||||
54
example/decoder_example.py
Normal file
54
example/decoder_example.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Пример использования декодера из архитектуры Transformer.
|
||||||
|
|
||||||
|
Этот пример демонстрирует:
|
||||||
|
1. Создание экземпляра декодера
|
||||||
|
2. Прямой проход через декодер
|
||||||
|
3. Работу с маской внимания
|
||||||
|
4. Инкрементальное декодирование
|
||||||
|
"""
|
||||||
|
import torch
|
||||||
|
from simple_llm.transformer.decoder import Decoder
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Конфигурация
|
||||||
|
num_heads = 4
|
||||||
|
emb_size = 64
|
||||||
|
head_size = 32
|
||||||
|
max_seq_len = 128
|
||||||
|
batch_size = 2
|
||||||
|
seq_len = 10
|
||||||
|
|
||||||
|
# 1. Создаем декодер
|
||||||
|
decoder = Decoder(
|
||||||
|
num_heads=num_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len
|
||||||
|
)
|
||||||
|
print("Декодер успешно создан:")
|
||||||
|
print(decoder)
|
||||||
|
print(f"Всего параметров: {sum(p.numel() for p in decoder.parameters()):,}")
|
||||||
|
|
||||||
|
# 2. Создаем тестовые данные
|
||||||
|
x = torch.randn(batch_size, seq_len, emb_size)
|
||||||
|
print(f"\nВходные данные: {x.shape}")
|
||||||
|
|
||||||
|
# 3. Прямой проход без маски
|
||||||
|
output = decoder(x)
|
||||||
|
print(f"\nВыход без маски: {output.shape}")
|
||||||
|
|
||||||
|
# 4. Прямой проход с маской
|
||||||
|
mask = torch.tril(torch.ones(seq_len, seq_len)) # Нижнетреугольная маска
|
||||||
|
masked_output = decoder(x, mask)
|
||||||
|
print(f"Выход с маской: {masked_output.shape}")
|
||||||
|
|
||||||
|
# 5. Инкрементальное декодирование (имитация генерации)
|
||||||
|
print("\nИнкрементальное декодирование:")
|
||||||
|
for i in range(1, 5):
|
||||||
|
step_input = x[:, :i, :]
|
||||||
|
step_mask = torch.tril(torch.ones(i, i))
|
||||||
|
step_output = decoder(step_input, step_mask)
|
||||||
|
print(f"Шаг {i}: вход {step_input.shape}, выход {step_output.shape}")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
96
simple_llm/transformer/decoder.py
Normal file
96
simple_llm/transformer/decoder.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from torch import nn
|
||||||
|
import torch
|
||||||
|
from .feed_forward import FeedForward
|
||||||
|
from .multi_head_attention import MultiHeadAttention
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
|
||||||
|
"""
|
||||||
|
Декодер трансформера - ключевой компонент архитектуры Transformer.
|
||||||
|
|
||||||
|
Предназначен для:
|
||||||
|
- Обработки последовательностей с учетом контекста (самовнимание)
|
||||||
|
- Постепенного генерирования выходной последовательности
|
||||||
|
- Учета масок для предотвращения "заглядывания в будущее"
|
||||||
|
|
||||||
|
Алгоритм работы:
|
||||||
|
1. Входной тензор (batch_size, seq_len, emb_size)
|
||||||
|
2. Многоголовое внимание с residual connection и LayerNorm
|
||||||
|
3. FeedForward сеть с residual connection и LayerNorm
|
||||||
|
4. Выходной тензор (batch_size, seq_len, emb_size)
|
||||||
|
|
||||||
|
Основные характеристики:
|
||||||
|
- Поддержка масок внимания
|
||||||
|
- Residual connections для стабилизации градиентов
|
||||||
|
- Layer Normalization после каждого sub-layer
|
||||||
|
- Конфигурируемые параметры внимания
|
||||||
|
|
||||||
|
Примеры использования:
|
||||||
|
|
||||||
|
1. Базовый случай:
|
||||||
|
>>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
|
||||||
|
>>> x = torch.randn(1, 10, 512) # [batch, seq_len, emb_size]
|
||||||
|
>>> output = decoder(x)
|
||||||
|
>>> print(output.shape)
|
||||||
|
torch.Size([1, 10, 512])
|
||||||
|
|
||||||
|
2. С маской внимания:
|
||||||
|
>>> mask = torch.tril(torch.ones(10, 10)) # Нижнетреугольная маска
|
||||||
|
>>> output = decoder(x, mask)
|
||||||
|
|
||||||
|
3. Инкрементальное декодирование:
|
||||||
|
>>> for i in range(10):
|
||||||
|
>>> output = decoder(x[:, :i+1, :], mask[:i+1, :i+1])
|
||||||
|
"""
|
||||||
|
def __init__(self,
|
||||||
|
num_heads: int,
|
||||||
|
emb_size: int,
|
||||||
|
head_size: int,
|
||||||
|
max_seq_len: int,
|
||||||
|
dropout: float = 0.1
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Инициализация декодера.
|
||||||
|
|
||||||
|
Параметры:
|
||||||
|
num_heads: int - количество голов внимания
|
||||||
|
emb_size: int - размерность эмбеддингов
|
||||||
|
head_size: int - размерность каждой головы внимания
|
||||||
|
max_seq_len: int - максимальная длина последовательности
|
||||||
|
dropout: float (default=0.1) - вероятность dropout
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
self._heads = MultiHeadAttention(
|
||||||
|
num_heads=num_heads,
|
||||||
|
emb_size=emb_size,
|
||||||
|
head_size=head_size,
|
||||||
|
max_seq_len=max_seq_len,
|
||||||
|
dropout=dropout
|
||||||
|
)
|
||||||
|
self._ff = FeedForward(emb_size=emb_size, dropout=dropout)
|
||||||
|
self._norm1 = nn.LayerNorm(emb_size)
|
||||||
|
self._norm2 = nn.LayerNorm(emb_size)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
Прямой проход через декодер.
|
||||||
|
|
||||||
|
Вход:
|
||||||
|
x: torch.Tensor - входной тензор [batch_size, seq_len, emb_size]
|
||||||
|
mask: torch.Tensor (optional) - маска внимания [seq_len, seq_len]
|
||||||
|
|
||||||
|
Возвращает:
|
||||||
|
torch.Tensor - выходной тензор [batch_size, seq_len, emb_size]
|
||||||
|
|
||||||
|
Алгоритм forward:
|
||||||
|
1. Применяем MultiHeadAttention к входу
|
||||||
|
2. Добавляем residual connection и LayerNorm
|
||||||
|
3. Применяем FeedForward сеть
|
||||||
|
4. Добавляем residual connection и LayerNorm
|
||||||
|
"""
|
||||||
|
# Self-Attention блок
|
||||||
|
attention = self._heads(x, mask)
|
||||||
|
out = self._norm1(attention + x)
|
||||||
|
|
||||||
|
# FeedForward блок
|
||||||
|
ffn_out = self._ff(out)
|
||||||
|
return self._norm2(ffn_out + out)
|
||||||
86
tests/test_decoder.py
Normal file
86
tests/test_decoder.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
import torch
|
||||||
|
import pytest
|
||||||
|
from simple_llm.transformer.decoder import Decoder
|
||||||
|
from simple_llm.transformer.multi_head_attention import MultiHeadAttention
|
||||||
|
from simple_llm.transformer.feed_forward import FeedForward
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
class TestDecoder:
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_decoder(self):
|
||||||
|
"""Фикстура с тестовым декодером"""
|
||||||
|
return Decoder(
|
||||||
|
num_heads=4,
|
||||||
|
emb_size=64,
|
||||||
|
head_size=16,
|
||||||
|
max_seq_len=128,
|
||||||
|
dropout=0.1
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_initialization(self, sample_decoder):
|
||||||
|
"""Тест инициализации слоёв"""
|
||||||
|
assert isinstance(sample_decoder._heads, MultiHeadAttention)
|
||||||
|
assert isinstance(sample_decoder._ff, FeedForward)
|
||||||
|
assert isinstance(sample_decoder._norm1, nn.LayerNorm)
|
||||||
|
assert isinstance(sample_decoder._norm2, nn.LayerNorm)
|
||||||
|
assert sample_decoder._norm1.normalized_shape == (64,)
|
||||||
|
assert sample_decoder._norm2.normalized_shape == (64,)
|
||||||
|
|
||||||
|
def test_forward_shapes(self, sample_decoder):
|
||||||
|
"""Тест сохранения размерностей"""
|
||||||
|
test_cases = [
|
||||||
|
(1, 10, 64), # batch=1, seq_len=10
|
||||||
|
(4, 25, 64), # batch=4, seq_len=25
|
||||||
|
(8, 50, 64) # batch=8, seq_len=50
|
||||||
|
]
|
||||||
|
|
||||||
|
for batch, seq_len, emb_size in test_cases:
|
||||||
|
x = torch.randn(batch, seq_len, emb_size)
|
||||||
|
output = sample_decoder(x)
|
||||||
|
assert output.shape == (batch, seq_len, emb_size)
|
||||||
|
|
||||||
|
def test_masking(self, sample_decoder):
|
||||||
|
"""Тест работы маски внимания"""
|
||||||
|
x = torch.randn(2, 10, 64)
|
||||||
|
mask = torch.tril(torch.ones(10, 10)) # Нижнетреугольная маска
|
||||||
|
|
||||||
|
output = sample_decoder(x, mask)
|
||||||
|
assert not torch.isnan(output).any()
|
||||||
|
|
||||||
|
def test_sample_case(self):
|
||||||
|
"""Тест на соответствие sample-тесту из условия"""
|
||||||
|
torch.manual_seed(0)
|
||||||
|
decoder = Decoder(
|
||||||
|
num_heads=5,
|
||||||
|
emb_size=12,
|
||||||
|
head_size=8,
|
||||||
|
max_seq_len=20,
|
||||||
|
dropout=0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Инициализируем веса нулями для предсказуемости
|
||||||
|
for p in decoder.parameters():
|
||||||
|
nn.init.zeros_(p)
|
||||||
|
|
||||||
|
x = torch.zeros(1, 12, 12)
|
||||||
|
output = decoder(x)
|
||||||
|
|
||||||
|
# Проверка формы вывода
|
||||||
|
assert output.shape == (1, 12, 12)
|
||||||
|
|
||||||
|
# Проверка что выход близок к нулю (но не точное значение)
|
||||||
|
assert torch.allclose(output, torch.zeros_like(output), atol=1e-6), \
|
||||||
|
"Output should be close to zero with zero-initialized weights"
|
||||||
|
|
||||||
|
def test_gradient_flow(self, sample_decoder):
|
||||||
|
"""Тест потока градиентов"""
|
||||||
|
x = torch.randn(2, 15, 64, requires_grad=True)
|
||||||
|
output = sample_decoder(x)
|
||||||
|
loss = output.sum()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
assert x.grad is not None
|
||||||
|
assert not torch.isnan(x.grad).any()
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main(["-v", __file__])
|
||||||
Reference in New Issue
Block a user