Files
simple-llm/example/multi_head_attention_example.py

74 lines
2.6 KiB
Python
Raw Permalink Normal View History

"""
Пример использования MultiHeadAttention с визуализацией весов внимания
"""
import torch
import matplotlib.pyplot as plt
import os
from simple_llm.transformer.multi_head_attention import MultiHeadAttention
def plot_attention_heads(weights, num_heads, filename="multi_head_attention.png"):
"""Визуализация матриц внимания для всех голов"""
fig, axes = plt.subplots(1, num_heads, figsize=(20, 5))
for i in range(num_heads):
ax = axes[i]
img = ax.imshow(weights[0, i].detach().numpy(), cmap='viridis')
ax.set_title(f'Голова {i+1}')
ax.set_xlabel('Key позиции')
ax.set_ylabel('Query позиции')
fig.colorbar(img, ax=ax)
plt.suptitle('Матрицы внимания по головам')
plt.tight_layout()
# Создаем папку если нет
os.makedirs('example_output', exist_ok=True)
plt.savefig(f'example_output/{filename}')
plt.close()
def main():
# Конфигурация
num_heads = 4
emb_size = 64
head_size = 16
seq_len = 10
batch_size = 1
# Инициализация
torch.manual_seed(42)
mha = MultiHeadAttention(
num_heads=num_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=20
)
# Тестовые данные
x = torch.randn(batch_size, seq_len, emb_size)
# Прямой проход
output = mha(x)
# Получаем веса из всех голов
attn_weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1)
print("Входная форма:", x.shape)
print("Выходная форма:", output.shape)
print("Форма весов внимания:", attn_weights.shape)
# Визуализация
plot_attention_heads(attn_weights, num_heads)
# Демонстрация работы механизма
print("\nПример с ручными зависимостями:")
x = torch.zeros(batch_size, 3, emb_size)
x[:, 1, :] = 2.0 # Яркий токен
# Получаем веса внимания
weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1)
for head in range(num_heads):
print(f"Голова {head+1} веса для токена 1:",
weights[0, head, 1].detach().round(decimals=3))
if __name__ == "__main__":
main()
print("\nГотово! Графики сохранены в example_output/multi_head_attention.png")