diff --git a/README.md b/README.md index 272f4d0..ca35087 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ - `HeadAttention` - механизм внимания одной головы - `MultiHeadAttention` - многоголовое внимание (4-16 голов) - `FeedForward` - двухслойная FFN сеть (расширение → сжатие) +- `Decoder` - полный декодер Transformer (Self-Attention + FFN) ## Быстрый старт @@ -40,12 +41,14 @@ model = nn.Sequential( - [Токенизация](/doc/bpe_algorithm.md) - [MultiHeadAttention](/doc/multi_head_attention_ru.md) - [FeedForward](/doc/feed_forward_ru.md) +- [Decoder](/doc/decoder_ru.md) ## Примеры ```bash # Запуск примеров python -m example.multi_head_attention_example # Визуализация внимания python -m example.feed_forward_example # Анализ FFN слоя +python -m example.decoder_example # Демонстрация декодера ``` ## Установка diff --git a/doc/decoder_ru.md b/doc/decoder_ru.md new file mode 100644 index 0000000..32327c3 --- /dev/null +++ b/doc/decoder_ru.md @@ -0,0 +1,90 @@ +# Декодер Transformer + +## Назначение +Декодер - ключевой компонент архитектуры Transformer, предназначенный для: +- Генерации последовательностей (текст, код и др.) +- Обработки входных данных с учетом контекста +- Постепенного построения выходной последовательности +- Работы с масками внимания (предотвращение "утечки" будущего) + +## Алгоритм работы + +```mermaid +graph TD + A[Входной тензор] --> B[Многоголовое внимание] + B --> C[Residual + LayerNorm] + C --> D[FeedForward Network] + D --> E[Residual + LayerNorm] + E --> F[Выходной тензор] + + style A fill:#f9f,stroke:#333 + style F fill:#bbf,stroke:#333 +``` + +1. **Self-Attention**: + - Вычисление внимания между всеми позициями + - Учет масок для автопрегрессивного декодирования + - Multi-head механизм (параллельные вычисления) + +2. **Residual Connection + LayerNorm**: + - Стабилизация градиентов + - Ускорение обучения + - Нормализация активаций + +3. **FeedForward Network**: + - Нелинейное преобразование + - Расширение скрытого пространства + - Дополнительная емкость модели + +## Использование + +```python +from simple_llm.transformer.decoder import Decoder + +# Инициализация +decoder = Decoder( + num_heads=8, + emb_size=512, + head_size=64, + max_seq_len=1024 +) + +# Прямой проход +x = torch.randn(1, 10, 512) # [batch, seq_len, emb_size] +output = decoder(x) + +# С маской +mask = torch.tril(torch.ones(10, 10)) +masked_output = decoder(x, mask) +``` + +## Параметры + +| Параметр | Тип | Описание | +|--------------|------|----------| +| num_heads | int | Количество голов внимания | +| emb_size | int | Размерность эмбеддингов | +| head_size | int | Размерность каждой головы | +| max_seq_len | int | Макс. длина последовательности | +| dropout | float| Вероятность дропаута (0.1 по умолч.) | + +## Применение в архитектурах + +- GPT (автопрегрессивные модели) +- Нейронный машинный перевод +- Генерация текста +- Кодогенерация + +## Особенности реализации + +1. **Масштабирование**: + - Поддержка длинных последовательностей + - Оптимизированные вычисления внимания + +2. **Обучение**: + - Поддержка teacher forcing + - Автопрегрессивное декодирование + +3. **Оптимизации**: + - Кэширование ключей/значений + - Пакетная обработка diff --git a/example/decoder_example.py b/example/decoder_example.py new file mode 100644 index 0000000..4c68f83 --- /dev/null +++ b/example/decoder_example.py @@ -0,0 +1,54 @@ +"""Пример использования декодера из архитектуры Transformer. + +Этот пример демонстрирует: +1. Создание экземпляра декодера +2. Прямой проход через декодер +3. Работу с маской внимания +4. Инкрементальное декодирование +""" +import torch +from simple_llm.transformer.decoder import Decoder + +def main(): + # Конфигурация + num_heads = 4 + emb_size = 64 + head_size = 32 + max_seq_len = 128 + batch_size = 2 + seq_len = 10 + + # 1. Создаем декодер + decoder = Decoder( + num_heads=num_heads, + emb_size=emb_size, + head_size=head_size, + max_seq_len=max_seq_len + ) + print("Декодер успешно создан:") + print(decoder) + print(f"Всего параметров: {sum(p.numel() for p in decoder.parameters()):,}") + + # 2. Создаем тестовые данные + x = torch.randn(batch_size, seq_len, emb_size) + print(f"\nВходные данные: {x.shape}") + + # 3. Прямой проход без маски + output = decoder(x) + print(f"\nВыход без маски: {output.shape}") + + # 4. Прямой проход с маской + mask = torch.tril(torch.ones(seq_len, seq_len)) # Нижнетреугольная маска + masked_output = decoder(x, mask) + print(f"Выход с маской: {masked_output.shape}") + + # 5. Инкрементальное декодирование (имитация генерации) + print("\nИнкрементальное декодирование:") + for i in range(1, 5): + step_input = x[:, :i, :] + step_mask = torch.tril(torch.ones(i, i)) + step_output = decoder(step_input, step_mask) + print(f"Шаг {i}: вход {step_input.shape}, выход {step_output.shape}") + +if __name__ == "__main__": + main() diff --git a/simple_llm/transformer/decoder.py b/simple_llm/transformer/decoder.py new file mode 100644 index 0000000..3ce5cdc --- /dev/null +++ b/simple_llm/transformer/decoder.py @@ -0,0 +1,96 @@ +from torch import nn +import torch +from .feed_forward import FeedForward +from .multi_head_attention import MultiHeadAttention + +class Decoder(nn.Module): + """ + Декодер трансформера - ключевой компонент архитектуры Transformer. + + Предназначен для: + - Обработки последовательностей с учетом контекста (самовнимание) + - Постепенного генерирования выходной последовательности + - Учета масок для предотвращения "заглядывания в будущее" + + Алгоритм работы: + 1. Входной тензор (batch_size, seq_len, emb_size) + 2. Многоголовое внимание с residual connection и LayerNorm + 3. FeedForward сеть с residual connection и LayerNorm + 4. Выходной тензор (batch_size, seq_len, emb_size) + + Основные характеристики: + - Поддержка масок внимания + - Residual connections для стабилизации градиентов + - Layer Normalization после каждого sub-layer + - Конфигурируемые параметры внимания + + Примеры использования: + + 1. Базовый случай: + >>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024) + >>> x = torch.randn(1, 10, 512) # [batch, seq_len, emb_size] + >>> output = decoder(x) + >>> print(output.shape) + torch.Size([1, 10, 512]) + + 2. С маской внимания: + >>> mask = torch.tril(torch.ones(10, 10)) # Нижнетреугольная маска + >>> output = decoder(x, mask) + + 3. Инкрементальное декодирование: + >>> for i in range(10): + >>> output = decoder(x[:, :i+1, :], mask[:i+1, :i+1]) + """ + def __init__(self, + num_heads: int, + emb_size: int, + head_size: int, + max_seq_len: int, + dropout: float = 0.1 + ): + """ + Инициализация декодера. + + Параметры: + num_heads: int - количество голов внимания + emb_size: int - размерность эмбеддингов + head_size: int - размерность каждой головы внимания + max_seq_len: int - максимальная длина последовательности + dropout: float (default=0.1) - вероятность dropout + """ + super().__init__() + self._heads = MultiHeadAttention( + num_heads=num_heads, + emb_size=emb_size, + head_size=head_size, + max_seq_len=max_seq_len, + dropout=dropout + ) + self._ff = FeedForward(emb_size=emb_size, dropout=dropout) + self._norm1 = nn.LayerNorm(emb_size) + self._norm2 = nn.LayerNorm(emb_size) + + def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Прямой проход через декодер. + + Вход: + x: torch.Tensor - входной тензор [batch_size, seq_len, emb_size] + mask: torch.Tensor (optional) - маска внимания [seq_len, seq_len] + + Возвращает: + torch.Tensor - выходной тензор [batch_size, seq_len, emb_size] + + Алгоритм forward: + 1. Применяем MultiHeadAttention к входу + 2. Добавляем residual connection и LayerNorm + 3. Применяем FeedForward сеть + 4. Добавляем residual connection и LayerNorm + """ + # Self-Attention блок + attention = self._heads(x, mask) + out = self._norm1(attention + x) + + # FeedForward блок + ffn_out = self._ff(out) + return self._norm2(ffn_out + out) \ No newline at end of file diff --git a/tests/test_decoder.py b/tests/test_decoder.py new file mode 100644 index 0000000..4a43c92 --- /dev/null +++ b/tests/test_decoder.py @@ -0,0 +1,86 @@ +import torch +import pytest +from simple_llm.transformer.decoder import Decoder +from simple_llm.transformer.multi_head_attention import MultiHeadAttention +from simple_llm.transformer.feed_forward import FeedForward +from torch import nn + +class TestDecoder: + @pytest.fixture + def sample_decoder(self): + """Фикстура с тестовым декодером""" + return Decoder( + num_heads=4, + emb_size=64, + head_size=16, + max_seq_len=128, + dropout=0.1 + ) + + def test_initialization(self, sample_decoder): + """Тест инициализации слоёв""" + assert isinstance(sample_decoder._heads, MultiHeadAttention) + assert isinstance(sample_decoder._ff, FeedForward) + assert isinstance(sample_decoder._norm1, nn.LayerNorm) + assert isinstance(sample_decoder._norm2, nn.LayerNorm) + assert sample_decoder._norm1.normalized_shape == (64,) + assert sample_decoder._norm2.normalized_shape == (64,) + + def test_forward_shapes(self, sample_decoder): + """Тест сохранения размерностей""" + test_cases = [ + (1, 10, 64), # batch=1, seq_len=10 + (4, 25, 64), # batch=4, seq_len=25 + (8, 50, 64) # batch=8, seq_len=50 + ] + + for batch, seq_len, emb_size in test_cases: + x = torch.randn(batch, seq_len, emb_size) + output = sample_decoder(x) + assert output.shape == (batch, seq_len, emb_size) + + def test_masking(self, sample_decoder): + """Тест работы маски внимания""" + x = torch.randn(2, 10, 64) + mask = torch.tril(torch.ones(10, 10)) # Нижнетреугольная маска + + output = sample_decoder(x, mask) + assert not torch.isnan(output).any() + + def test_sample_case(self): + """Тест на соответствие sample-тесту из условия""" + torch.manual_seed(0) + decoder = Decoder( + num_heads=5, + emb_size=12, + head_size=8, + max_seq_len=20, + dropout=0.0 + ) + + # Инициализируем веса нулями для предсказуемости + for p in decoder.parameters(): + nn.init.zeros_(p) + + x = torch.zeros(1, 12, 12) + output = decoder(x) + + # Проверка формы вывода + assert output.shape == (1, 12, 12) + + # Проверка что выход близок к нулю (но не точное значение) + assert torch.allclose(output, torch.zeros_like(output), atol=1e-6), \ + "Output should be close to zero with zero-initialized weights" + + def test_gradient_flow(self, sample_decoder): + """Тест потока градиентов""" + x = torch.randn(2, 15, 64, requires_grad=True) + output = sample_decoder(x) + loss = output.sum() + loss.backward() + + assert x.grad is not None + assert not torch.isnan(x.grad).any() + +if __name__ == "__main__": + pytest.main(["-v", __file__])