mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
55 lines
1.9 KiB
Python
55 lines
1.9 KiB
Python
|
|
"""Пример использования декодера из архитектуры Transformer.
|
|||
|
|
|
|||
|
|
Этот пример демонстрирует:
|
|||
|
|
1. Создание экземпляра декодера
|
|||
|
|
2. Прямой проход через декодер
|
|||
|
|
3. Работу с маской внимания
|
|||
|
|
4. Инкрементальное декодирование
|
|||
|
|
"""
|
|||
|
|
import torch
|
|||
|
|
from simple_llm.transformer.decoder import Decoder
|
|||
|
|
|
|||
|
|
def main():
|
|||
|
|
# Конфигурация
|
|||
|
|
num_heads = 4
|
|||
|
|
emb_size = 64
|
|||
|
|
head_size = 32
|
|||
|
|
max_seq_len = 128
|
|||
|
|
batch_size = 2
|
|||
|
|
seq_len = 10
|
|||
|
|
|
|||
|
|
# 1. Создаем декодер
|
|||
|
|
decoder = Decoder(
|
|||
|
|
num_heads=num_heads,
|
|||
|
|
emb_size=emb_size,
|
|||
|
|
head_size=head_size,
|
|||
|
|
max_seq_len=max_seq_len
|
|||
|
|
)
|
|||
|
|
print("Декодер успешно создан:")
|
|||
|
|
print(decoder)
|
|||
|
|
print(f"Всего параметров: {sum(p.numel() for p in decoder.parameters()):,}")
|
|||
|
|
|
|||
|
|
# 2. Создаем тестовые данные
|
|||
|
|
x = torch.randn(batch_size, seq_len, emb_size)
|
|||
|
|
print(f"\nВходные данные: {x.shape}")
|
|||
|
|
|
|||
|
|
# 3. Прямой проход без маски
|
|||
|
|
output = decoder(x)
|
|||
|
|
print(f"\nВыход без маски: {output.shape}")
|
|||
|
|
|
|||
|
|
# 4. Прямой проход с маской
|
|||
|
|
mask = torch.tril(torch.ones(seq_len, seq_len)) # Нижнетреугольная маска
|
|||
|
|
masked_output = decoder(x, mask)
|
|||
|
|
print(f"Выход с маской: {masked_output.shape}")
|
|||
|
|
|
|||
|
|
# 5. Инкрементальное декодирование (имитация генерации)
|
|||
|
|
print("\nИнкрементальное декодирование:")
|
|||
|
|
for i in range(1, 5):
|
|||
|
|
step_input = x[:, :i, :]
|
|||
|
|
step_mask = torch.tril(torch.ones(i, i))
|
|||
|
|
step_output = decoder(step_input, step_mask)
|
|||
|
|
print(f"Шаг {i}: вход {step_input.shape}, выход {step_output.shape}")
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
main()
|