Mirror of https://github.com/pese-git/simple-llm.git (synced 2026-01-23 21:14:17 +00:00)
"""
Example of using MultiHeadAttention with visualization of the attention weights.
"""

import os

import matplotlib.pyplot as plt
import torch

from simple_llm.transformer.multi_head_attention import MultiHeadAttention
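
# Note: this example relies on the repo's MultiHeadAttention exposing its per-head
# attention modules through the private `_heads` attribute, and on each head
# providing a `get_attention_weights(x)` method; those are the only internals
# touched below.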


def plot_attention_heads(weights, num_heads, filename="multi_head_attention.png"):
    """Visualize the attention matrices of all heads."""
    fig, axes = plt.subplots(1, num_heads, figsize=(20, 5))
    if num_heads == 1:
        # plt.subplots returns a single Axes (not an array) for one subplot
        axes = [axes]
    for i in range(num_heads):
        ax = axes[i]
        img = ax.imshow(weights[0, i].detach().numpy(), cmap='viridis')
        ax.set_title(f'Head {i+1}')
        ax.set_xlabel('Key positions')
        ax.set_ylabel('Query positions')
        fig.colorbar(img, ax=ax)
    plt.suptitle('Attention matrices per head')
    plt.tight_layout()

    # Create the output directory if it does not exist
    os.makedirs('example_output', exist_ok=True)
    plt.savefig(f'example_output/{filename}')
    plt.close()
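

# The helper below is an optional sketch and not part of the original example: it
# averages the per-head attention maps into a single matrix, which can be easier to
# eyeball than separate panels. It assumes the same weight layout as
# plot_attention_heads, i.e. (batch, num_heads, query_len, key_len).
def plot_attention_mean(weights, filename="multi_head_attention_mean.png"):
    """Visualize the attention matrix averaged over all heads (optional sketch)."""
    mean_weights = weights[0].mean(dim=0).detach().numpy()  # (query_len, key_len)
    fig, ax = plt.subplots(figsize=(6, 5))
    img = ax.imshow(mean_weights, cmap='viridis')
    ax.set_title('Attention averaged over heads')
    ax.set_xlabel('Key positions')
    ax.set_ylabel('Query positions')
    fig.colorbar(img, ax=ax)
    plt.tight_layout()
    os.makedirs('example_output', exist_ok=True)
    plt.savefig(f'example_output/{filename}')
    plt.close()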


def main():
    # Configuration
    num_heads = 4
    emb_size = 64
    head_size = 16
    seq_len = 10
    batch_size = 1
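    # Note: head_size (16) equals emb_size // num_heads (64 // 4), the usual
    # convention, so the concatenated head outputs match the embedding size.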

    # Initialization
    torch.manual_seed(42)
    mha = MultiHeadAttention(
        num_heads=num_heads,
        emb_size=emb_size,
        head_size=head_size,
        max_seq_len=20
    )

    # Test data
    x = torch.randn(batch_size, seq_len, emb_size)

    # Forward pass
    output = mha(x)
    # Collect the attention weights from all heads
    attn_weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1)

    print("Input shape:", x.shape)
    print("Output shape:", output.shape)
    print("Attention weights shape:", attn_weights.shape)
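    # With this configuration, x is (1, 10, 64) and attn_weights should come out as
    # (batch, num_heads, seq_len, seq_len) = (1, 4, 10, 10); the output is expected
    # to keep the (1, 10, 64) shape here, since num_heads * head_size == emb_size.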

    # Visualization
    plot_attention_heads(attn_weights, num_heads)
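    # Optional: uncomment to also save the head-averaged map (see the sketch above).
    # plot_attention_mean(attn_weights)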

    # Demonstrating the mechanism on hand-crafted input
    print("\nExample with hand-crafted dependencies:")
    x = torch.zeros(batch_size, 3, emb_size)
    x[:, 1, :] = 2.0  # a "bright" token that stands out from the zero tokens

    # Get the attention weights for the new input
    weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1)
    for head in range(num_heads):
        print(f"Head {head+1} weights for token 1:",
              weights[0, head, 1].detach().round(decimals=3))
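    # Since the weights are untrained and most tokens are zero vectors, the printed
    # numbers are not meaningful in themselves; the point is that each head produces
    # its own distribution over key positions, which (assuming softmax attention)
    # sums to 1 in every row.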


if __name__ == "__main__":
    main()
    print("\nDone! The plots have been saved to example_output/multi_head_attention.png")