mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Добавление механизма внимания HeadAttention
- Реализация одного головного внимания из Transformer - Полная документация на русском языке - Пример использования с визуализацией - Обновление README с ссылками
This commit is contained in:
88
example/head_attention_example.py
Normal file
88
example/head_attention_example.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Пример использования механизма внимания HeadAttention
|
||||
с визуализацией матрицы внимания и анализом работы.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from simple_llm.transformer.head_attention import HeadAttention
|
||||
|
||||
import os
|
||||
os.makedirs("example_output", exist_ok=True)
|
||||
|
||||
def plot_attention(weights, tokens=None, filename="attention_plot.png"):
|
||||
"""Сохранение матрицы внимания в файл"""
|
||||
plt.figure(figsize=(10, 8))
|
||||
plt.imshow(weights, cmap='viridis')
|
||||
|
||||
if tokens:
|
||||
plt.xticks(range(len(tokens)), tokens, rotation=90)
|
||||
plt.yticks(range(len(tokens)), tokens)
|
||||
|
||||
plt.colorbar()
|
||||
plt.title("Матрица весов внимания")
|
||||
plt.xlabel("Key Positions")
|
||||
plt.ylabel("Query Positions")
|
||||
plt.savefig(f"example_output/{filename}")
|
||||
plt.close()
|
||||
|
||||
def simulate_text_attention():
|
||||
"""Пример с имитацией текстовых данных"""
|
||||
# Параметры
|
||||
emb_size = 64
|
||||
head_size = 32
|
||||
seq_len = 8
|
||||
|
||||
# Имитация токенов
|
||||
tokens = ["[CLS]", "мама", "мыла", "раму", ",", "папа", "пил", "какао"]
|
||||
|
||||
# Инициализация
|
||||
torch.manual_seed(42)
|
||||
attention = HeadAttention(emb_size, head_size, max_seq_len=seq_len)
|
||||
|
||||
# Случайные эмбеддинги (в реальности - выход слоя токенизации)
|
||||
x = torch.randn(1, seq_len, emb_size)
|
||||
|
||||
# Прямой проход + получение весов
|
||||
with torch.no_grad():
|
||||
output = attention(x)
|
||||
q, k = attention._q(x), attention._k(x)
|
||||
scores = (q @ k.transpose(-2, -1)) / np.sqrt(head_size)
|
||||
weights = torch.softmax(scores, dim=-1).squeeze()
|
||||
|
||||
# Визуализация
|
||||
print("\nПример для фразы:", " ".join(tokens))
|
||||
print("Форма выходного тензора:", output.shape)
|
||||
plot_attention(weights.numpy(), tokens)
|
||||
|
||||
def technical_demo():
|
||||
"""Техническая демонстрация работы механизма"""
|
||||
print("\nТехническая демонстрация HeadAttention")
|
||||
attention = HeadAttention(emb_size=16, head_size=8, max_seq_len=10)
|
||||
|
||||
# Создаем тензор с ручными значениями для анализа
|
||||
x = torch.zeros(1, 4, 16)
|
||||
x[0, 0, :] = 1.0 # Яркий токен
|
||||
x[0, 3, :] = 0.5 # Слабый токен
|
||||
|
||||
# Анализ весов
|
||||
with torch.no_grad():
|
||||
output = attention(x)
|
||||
q = attention._q(x)
|
||||
k = attention._k(x)
|
||||
print("\nQuery векторы (первые 5 значений):")
|
||||
print(q[0, :, :5])
|
||||
|
||||
print("\nKey векторы (первые 5 значений):")
|
||||
print(k[0, :, :5])
|
||||
|
||||
weights = torch.softmax((q @ k.transpose(-2, -1)) / np.sqrt(8), dim=-1)
|
||||
print("\nМатрица внимания:")
|
||||
print(weights.squeeze().round(decimals=3))
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Демонстрация работы HeadAttention")
|
||||
simulate_text_attention()
|
||||
technical_demo()
|
||||
print("\nГотово! Проверьте графики матрицы внимания.")
|
||||
Reference in New Issue
Block a user