mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 13:48:01 +00:00
Реализация MultiHeadAttention
- Добавлен класс MultiHeadAttention - Создана документация с блок-схемой - Добавлен пример использования - Обновлен README.md
This commit is contained in:
73
example/multi_head_attention_example.py
Normal file
73
example/multi_head_attention_example.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Пример использования MultiHeadAttention с визуализацией весов внимания
|
||||
"""
|
||||
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
from simple_llm.transformer.multi_head_attention import MultiHeadAttention
|
||||
|
||||
def plot_attention_heads(weights, num_heads, filename="multi_head_attention.png"):
|
||||
"""Визуализация матриц внимания для всех голов"""
|
||||
fig, axes = plt.subplots(1, num_heads, figsize=(20, 5))
|
||||
for i in range(num_heads):
|
||||
ax = axes[i]
|
||||
img = ax.imshow(weights[0, i].detach().numpy(), cmap='viridis')
|
||||
ax.set_title(f'Голова {i+1}')
|
||||
ax.set_xlabel('Key позиции')
|
||||
ax.set_ylabel('Query позиции')
|
||||
fig.colorbar(img, ax=ax)
|
||||
plt.suptitle('Матрицы внимания по головам')
|
||||
plt.tight_layout()
|
||||
|
||||
# Создаем папку если нет
|
||||
os.makedirs('example_output', exist_ok=True)
|
||||
plt.savefig(f'example_output/{filename}')
|
||||
plt.close()
|
||||
|
||||
def main():
|
||||
# Конфигурация
|
||||
num_heads = 4
|
||||
emb_size = 64
|
||||
head_size = 16
|
||||
seq_len = 10
|
||||
batch_size = 1
|
||||
|
||||
# Инициализация
|
||||
torch.manual_seed(42)
|
||||
mha = MultiHeadAttention(
|
||||
num_heads=num_heads,
|
||||
emb_size=emb_size,
|
||||
head_size=head_size,
|
||||
max_seq_len=20
|
||||
)
|
||||
|
||||
# Тестовые данные
|
||||
x = torch.randn(batch_size, seq_len, emb_size)
|
||||
|
||||
# Прямой проход
|
||||
output = mha(x)
|
||||
# Получаем веса из всех голов
|
||||
attn_weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1)
|
||||
|
||||
print("Входная форма:", x.shape)
|
||||
print("Выходная форма:", output.shape)
|
||||
print("Форма весов внимания:", attn_weights.shape)
|
||||
|
||||
# Визуализация
|
||||
plot_attention_heads(attn_weights, num_heads)
|
||||
|
||||
# Демонстрация работы механизма
|
||||
print("\nПример с ручными зависимостями:")
|
||||
x = torch.zeros(batch_size, 3, emb_size)
|
||||
x[:, 1, :] = 2.0 # Яркий токен
|
||||
|
||||
# Получаем веса внимания
|
||||
weights = torch.stack([head.get_attention_weights(x) for head in mha._heads], dim=1)
|
||||
for head in range(num_heads):
|
||||
print(f"Голова {head+1} веса для токена 1:",
|
||||
weights[0, head, 1].detach().round(decimals=3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
print("\nГотово! Графики сохранены в example_output/multi_head_attention.png")
|
||||
Reference in New Issue
Block a user