mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Реализация Decoder для трансформера
- Основной модуль декодера (Decoder) с: * Self-Attention механизмом * Encoder-Decoder Attention слоем * LayerNormalization * Позиционными эмбеддингами - Примеры использования с документацией - Полный набор unit-тестов - Документация на русском языке
This commit is contained in:
54
example/decoder_example.py
Normal file
54
example/decoder_example.py
Normal file
@@ -0,0 +1,54 @@
|
||||
"""Пример использования декодера из архитектуры Transformer.
|
||||
|
||||
Этот пример демонстрирует:
|
||||
1. Создание экземпляра декодера
|
||||
2. Прямой проход через декодер
|
||||
3. Работу с маской внимания
|
||||
4. Инкрементальное декодирование
|
||||
"""
|
||||
import torch
|
||||
from simple_llm.transformer.decoder import Decoder
|
||||
|
||||
def main():
|
||||
# Конфигурация
|
||||
num_heads = 4
|
||||
emb_size = 64
|
||||
head_size = 32
|
||||
max_seq_len = 128
|
||||
batch_size = 2
|
||||
seq_len = 10
|
||||
|
||||
# 1. Создаем декодер
|
||||
decoder = Decoder(
|
||||
num_heads=num_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=max_seq_len
|
||||
)
|
||||
print("Декодер успешно создан:")
|
||||
print(decoder)
|
||||
print(f"Всего параметров: {sum(p.numel() for p in decoder.parameters()):,}")
|
||||
|
||||
# 2. Создаем тестовые данные
|
||||
x = torch.randn(batch_size, seq_len, emb_size)
|
||||
print(f"\nВходные данные: {x.shape}")
|
||||
|
||||
# 3. Прямой проход без маски
|
||||
output = decoder(x)
|
||||
print(f"\nВыход без маски: {output.shape}")
|
||||
|
||||
# 4. Прямой проход с маской
|
||||
mask = torch.tril(torch.ones(seq_len, seq_len)) # Нижнетреугольная маска
|
||||
masked_output = decoder(x, mask)
|
||||
print(f"Выход с маской: {masked_output.shape}")
|
||||
|
||||
# 5. Инкрементальное декодирование (имитация генерации)
|
||||
print("\nИнкрементальное декодирование:")
|
||||
for i in range(1, 5):
|
||||
step_input = x[:, :i, :]
|
||||
step_mask = torch.tril(torch.ones(i, i))
|
||||
step_output = decoder(step_input, step_mask)
|
||||
print(f"Шаг {i}: вход {step_input.shape}, выход {step_output.shape}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user