diff --git a/experiments/llm_only/configs/mixtral_generate.json b/experiments/llm_only/configs/mixtral_generate.json new file mode 100644 index 0000000..5fb07e7 --- /dev/null +++ b/experiments/llm_only/configs/mixtral_generate.json @@ -0,0 +1,19 @@ +{ + "bpe_tokenizer": "checkpoints/bpe_tokenizer.json", + "test_prompts": [ + "Open weights", + "The Llama model is", + "Efficient transformers" + ], + "model_config_path": "checkpoints/mixtral-bpe/config.json", + "model_weights": "checkpoints/mixtral-bpe/model.pt", + "generation": { + "max_new_tokens": 40, + "temperature": 0.8, + "do_sample": true, + "top_k": null, + "top_p": null + }, + "log_path": "checkpoints/mixtral_only_generation_logs.json" + } + \ No newline at end of file diff --git a/experiments/llm_only/configs/mixtral_train.json b/experiments/llm_only/configs/mixtral_train.json new file mode 100644 index 0000000..67ae9ae --- /dev/null +++ b/experiments/llm_only/configs/mixtral_train.json @@ -0,0 +1,28 @@ +{ + "bpe_tokenizer": "checkpoints/bpe_tokenizer.json", + "bpe_vocab_size": 1000, + "bpe_special_tokens": ["", "", "", ""], + "test_prompts": ["Open source AI", "What is Llama?"], + "model_config": { + "vocab_size": null, + "embed_dim": 256, + "num_q_heads": 4, + "num_kv_heads": 2, + "head_size": 64, + "num_layers": 4, + "max_position_embeddings": 512, + "num_experts": 8, + "top_k_experts": 2, + "window_size": 16, + "dropout": 0.1 + }, + "model_weights": "checkpoints/mixtral-bpe/model.pt", + "model_config_path": "checkpoints/mixtral-bpe/config.json", + "training": { + "learning_rate": 0.0003, + "batch_size": 2, + "num_epochs": 3, + "warmup_steps": 50 + }, + "log_path": "checkpoints/mixtral_only_training_logs.json" + } \ No newline at end of file diff --git a/experiments/llm_only/run_llm_experiment.py b/experiments/llm_only/run_llm_experiment.py index b59d240..105315e 100644 --- a/experiments/llm_only/run_llm_experiment.py +++ b/experiments/llm_only/run_llm_experiment.py @@ -45,6 +45,9 @@ def load_model_class(model_name): elif model_name.lower() == 'mistral': from llm.models.mistral import Mistral return Mistral + elif model_name.lower() == 'mixtral': + from llm.models.mixtral import Mixtral + return Mixtral else: raise ValueError(f"Модель '{model_name}' не поддерживается.") diff --git a/llm/src/llm/core/moe.py b/llm/src/llm/core/moe.py new file mode 100644 index 0000000..4964b22 --- /dev/null +++ b/llm/src/llm/core/moe.py @@ -0,0 +1,92 @@ +import torch +from torch import nn +import torch.nn.functional as F +from llm.core.swi_glu import SwiGLU + +class MoE(nn.Module): + def __init__( + self, + emb_size: int, + num_experts: int, + top_k_experts: int, + dropout: float = 0.1, + ): + super().__init__() + self._num_experts = num_experts + self._top_k_experts = top_k_experts + + self._router = nn.Linear(emb_size, num_experts) + self._experts = nn.ModuleList([SwiGLU( + emb_size=emb_size, + dropout=dropout, + ) for _ in range(num_experts)]) + self._dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor): + batch_size, seq_len, emb_size = x.shape + + # 1. Пропускаем через роутер + router_logits = self._router(x) # [batch_size, seq_len, num_experts] + + # 2. Отбираем топ-k экспертов для каждого токена + topk_logits, topk_indices = torch.topk( + router_logits, + k=self._top_k_experts, + dim=-1 + ) # topk_logits: [batch_size, seq_len, top_k] + # topk_indices: [batch_size, seq_len, top_k] + + # 3. Получаем веса через softmax и нормируем + topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k] + + # 4. Создаём нулевой тензор для результата + output = torch.zeros_like(x) # [batch_size, seq_len, emb_size] + + # 5. Проходим по всем экспертам + for expert_id in range(self._num_experts): + # Шаг 1: Создаём маску - где находится текущий эксперт в топ-k + expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k] + # Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном + if not expert_mask.any(): + continue # Эксперт никем не выбран, переходим к следующему + + # Шаг 3: Находим токены, которые выбрали этого эксперта + # (хотя бы в одной из top_k позиций) + token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len] + + # Шаг 4: Отбираем токены из x + # Отбираем токены для этого эксперта + expert_input = x[token_mask] + + # Пропускаем через эксперта + # Добавляем batch dimension для SwiGLU и затем убираем + expert_output = self._experts[expert_id]( + expert_input.unsqueeze(0) + ).squeeze(0) + + # Получаем веса для этого эксперта + # Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз) + # Но на практике каждый эксперт появляется максимум 1 раз в топ-k + # Находим веса: где expert_mask == True, берём соответствующий вес + weights_for_expert = torch.zeros( + batch_size, seq_len, device=x.device + ) + + # Для каждой позиции в топ-k + for k in range(self._top_k_experts): + mask_k = topk_indices[:, :, k] == expert_id + weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k] + + # Отбираем только веса для выбранных токенов + selected_weights = weights_for_expert[token_mask] # [num_selected_tokens] + + + # Перемножьте выход эксперта на веса текущего эксперта. + weighted_output = selected_weights.unsqueeze(-1) * expert_output + + # Помещаем результат на своё место в выходном тензоре + output[token_mask] += weighted_output + + out = self._dropout(output) + + return out \ No newline at end of file diff --git a/llm/src/llm/models/mixtral/__init__.py b/llm/src/llm/models/mixtral/__init__.py new file mode 100644 index 0000000..9b6cdbf --- /dev/null +++ b/llm/src/llm/models/mixtral/__init__.py @@ -0,0 +1,3 @@ +from .mixtral import Mixtral + +__all__ = ["Mixtral"] diff --git a/llm/src/llm/models/mixtral/mixtral.py b/llm/src/llm/models/mixtral/mixtral.py new file mode 100644 index 0000000..6992196 --- /dev/null +++ b/llm/src/llm/models/mixtral/mixtral.py @@ -0,0 +1,295 @@ +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F +from math import sqrt +from llm.core.base_model import BaseModel +from llm.core.token_embeddings import TokenEmbeddings +from llm.core.rope import RoPE +from llm.core.rms_norm import RMSNorm +from llm.core.moe import MoE +from llm.core.group_query_attention import GroupedQueryAttention + + +class Decoder(nn.Module): + def __init__(self, + num_q_heads: int, + num_kv_heads: int, + emb_size: int, + head_size: int, + max_seq_len: int, + num_experts: int, + top_k_experts: int, + window_size: int, + rope: RoPE, + dropout: float = 0.1 + ): + super().__init__() + self._heads = GroupedQueryAttention( + num_q_heads=num_q_heads, + num_kv_heads=num_kv_heads, + emb_size=emb_size, + head_size=head_size, + max_seq_len=max_seq_len, + window_size=window_size, + rope=rope, + dropout=dropout + ) + self._ff = MoE( + emb_size=emb_size, + num_experts=num_experts, + top_k_experts=top_k_experts, + dropout=dropout + ) + self._norm1 = RMSNorm(emb_size) + self._norm2 = RMSNorm(emb_size) + + def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor: + norm1_out = self._norm1(x) + attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache) + out = attention + x + + norm2_out = self._norm2(out) + ffn_out = self._ff(norm2_out) + + if use_cache is True: + return (ffn_out + out, kv_caches) + else: + return (ffn_out + out, None) + + + +from torch import nn +import torch +import torch.nn.functional as F + +class Mixtral(BaseModel): + def __init__(self, config): + super().__init__(config) + + self._max_seq_len = config["max_position_embeddings"] + + # Инициализация слоев + self._token_embeddings = TokenEmbeddings( + vocab_size=config["vocab_size"], + emb_size=config["embed_dim"] + ) + self._position_embeddings = RoPE( + head_size=config["embed_dim"] // config["num_q_heads"], + max_seq_len=config["max_position_embeddings"] + ) + #self._position_embeddings = PositionalEmbeddings( + # max_seq_len=max_seq_len, + # emb_size=emb_size + #) + self._dropout = nn.Dropout(config["dropout"]) + self._decoders = nn.ModuleList([Decoder( + num_q_heads=config["num_q_heads"], + num_kv_heads=config["num_kv_heads"], + emb_size=config["embed_dim"], + head_size=config["embed_dim"] // config["num_q_heads"], + max_seq_len=config["max_position_embeddings"], + num_experts=config["num_experts"], + top_k_experts=config["top_k_experts"], + window_size=config["window_size"], + rope=self._position_embeddings, + dropout=config["dropout"] + ) for _ in range(config["num_layers"])]) + self._norm = RMSNorm(config["embed_dim"]) + self._linear = nn.Linear(config["embed_dim"], config["vocab_size"]) + + def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple: + # Проверка длины последовательности (только при отсутствии кэша) + if cache is None and x.size(1) > self._max_seq_len: + raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}") + + # Эмбеддинги токенов и позиций + tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size] + #pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size] + + # Комбинирование + out = self._dropout(tok_out) # [batch, seq_len, emb_size] + + # Стек декодеров с передачей кэша + new_cache = [] + for i, decoder in enumerate(self._decoders): + decoder_cache = cache[i] if cache is not None else None + decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache) + + # Извлекаем результат из кортежа + if use_cache: + out, decoder_new_cache = decoder_result + new_cache.append(decoder_new_cache) + else: + out = decoder_result[0] + + out = self._norm(out) + logits = self._linear(out) + + # Возвращаем результат с учетом use_cache + if use_cache: + return (logits, new_cache) + else: + return (logits, None) + + def generate(self, + x: torch.Tensor, + max_new_tokens: int, + do_sample: bool, + temperature: float = 1.0, + top_k: int = None, + top_p: float = None, + use_cache: bool = True + ) -> torch.Tensor: + """ + Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling + и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах). + + Аргументы: + x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len]. + max_new_tokens (int): Максимальное количество новых токенов для генерации. + do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax). + temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие. + top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов. + top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p. + use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима. + + Возвращает: + torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens]. + + Исключения: + ValueError: Если x длиннее max_seq_len модели. + ValueError: Если temperature ≤ 0. + ValueError: Если одновременно заданы top_k и top_p. + ValueError: Если top_k ≤ 0. + ValueError: Если top_p не в диапазоне (0, 1]. + + Примеры: + >>> # Жадная генерация + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False) + >>> # Сэмплирование с температурой + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8) + >>> # Top-k sampling + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50) + >>> # Top-p (nucleus) sampling + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92) + >>> # Температура + top-k + >>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100) + + Примечания: + - Одновременно использовать top_k и top_p нельзя. + - Параметры temperature, top_k, top_p работают только при do_sample=True. + - Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed. + - Метод всегда возвращает только индексы токенов; для получения логитов используйте forward. + + Ссылки: + - Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751 + - Mistral: https://arxiv.org/abs/2310.06825 + """ + cache = None + + for _ in range(max_new_tokens): + if use_cache and cache is not None: + # Используем кэш - передаем только последний токен + x_input = x[:, -1:] # [batch_size, 1] + else: + # Первая итерация или кэш отключен - передаем всю последовательность + x_input = x + + # Прямой проход с кэшем + logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache) + + # Обновляем кэш для следующей итерации + if use_cache: + cache = new_cache + + last_logits = logits[:, -1, :] # [batch_size, vocab_size] + + # Масштабируем логиты температурой + if temperature > 0: + logits_scaled = last_logits / temperature + else: + logits_scaled = last_logits + + if do_sample == True and top_k != None: + _, topk_indices = torch.topk(logits_scaled, top_k, dim=-1) + + # # Заменим все НЕ top-k логиты на -inf + masked_logits = logits_scaled.clone() + vocab_size = logits_scaled.size(-1) + + # создаём маску: 1, если токен НЕ в topk_indices + mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8) + mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы + masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf') + + logits_scaled = masked_logits + + if do_sample == True and top_p != None: + # 1. Применим softmax, чтобы получить вероятности: + probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size] + # 2. Отсортируем токены по убыванию вероятностей: + sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1) + # 3. Посчитаем кумулятивную сумму вероятностей: + cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size] + # 4. Определим маску: оставить токены, пока сумма < top_p + sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size] + # Гарантируем, что хотя бы первый токен останется + sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1 + # 5. Преобразуем маску обратно в оригинальный порядок: + # Создаём полную маску из 0 + mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8) + # Устанавливаем 1 в местах нужных токенов + mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask) + # 6. Зануляем логиты токенов вне топ-p: + logits_scaled[~mask] = float('-inf') + + # 4. Применяем Softmax + probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size] + + + if do_sample == True: + # 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial + next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1] + else: + # 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью + next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1] + + # 6. Добавляем его к последовательности + x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1] + return x + + + def save(self, path): + torch.save({ + 'model_state_dict': self.state_dict(), + 'vocab_size': self._vocab_size, + 'max_seq_len': self._max_seq_len, + 'emb_size': self._emb_size, + 'num_heads': self._num_heads, + 'head_size': self._head_size, + 'num_layers': self._num_layers + }, path) + + @classmethod + def load(cls, path, device): + checkpoint = torch.load(path, map_location=device) + model = cls( + vocab_size=checkpoint['vocab_size'], + max_seq_len=checkpoint['max_seq_len'], + emb_size=checkpoint['emb_size'], + num_heads=checkpoint['num_heads'], + head_size=checkpoint['head_size'], + num_layers=checkpoint['num_layers'] + ) + model.load_state_dict(checkpoint['model_state_dict']) + model.to(device) + return model + + @property + def max_seq_len(self) -> int: + return self._max_seq_len + + + + diff --git a/llm/tests/models/test_mixtral.py b/llm/tests/models/test_mixtral.py new file mode 100644 index 0000000..cc862fc --- /dev/null +++ b/llm/tests/models/test_mixtral.py @@ -0,0 +1,57 @@ +import torch +import pytest +from llm.models.mixtral.mixtral import Mixtral + +@pytest.fixture +def config(): + return { + "vocab_size": 100, + "embed_dim": 32, + "num_q_heads": 4, + "num_kv_heads": 2, + "num_layers": 2, + "max_position_embeddings": 16, + "window_size": 8, + "dropout": 0.0, + "num_experts": 4, + "top_k_experts": 2, + } + +@pytest.fixture +def model(config): + return Mixtral(config) + +def test_forward_basic(model): + x = torch.randint(0, 100, (2, 8)) + logits, cache = model(x) + assert logits.shape == (2, 8, 100) + assert isinstance(cache, list) + assert len(cache) == model._decoders.__len__() + +def test_forward_with_cache(model): + x = torch.randint(0, 100, (2, 4)) + logits, cache = model(x, use_cache=True) + x2 = torch.randint(0, 100, (2, 1)) + logits2, cache2 = model(x2, use_cache=True, cache=cache) + assert logits2.shape == (2, 1, 100) + assert isinstance(cache2, list) + +def test_generate_and_shape(model): + x = torch.randint(0, 100, (1, 5)) + result = model.generate(x, max_new_tokens=3, do_sample=False) + assert result.shape == (1, 8) + +def test_forward_sequence_too_long(model, config): + x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1)) + with pytest.raises(ValueError): + model(x) + +def test_generate_with_sampling_topk(model): + x = torch.randint(0, 100, (1, 3)) + out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5) + assert out.shape == (1, 5) + +def test_generate_with_sampling_topp(model): + x = torch.randint(0, 100, (1, 3)) + out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8) + assert out.shape == (1, 5) diff --git a/notebooks/mixstral.ipynb b/notebooks/mixstral.ipynb new file mode 100644 index 0000000..ebf0ade --- /dev/null +++ b/notebooks/mixstral.ipynb @@ -0,0 +1,1510 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c3276440", + "metadata": {}, + "source": [ + "# **Mixtral**\n", + "\n", + "

\n", + " \"Mixtral\"\n", + "

\n", + "\n", + "Mixtral 8x7B вышел в декабре 2023.\n", + "\n", + "В нем сохранились все архитектурные особенности первого Mistral: RoPE, SwiGLU, GQA и SWA. И добавилась одна суперзначимая новинка — Mixture-of-Experts (MoE). Основное предназначение MoE — ускорение инференса для очень крупных моделей (с большим багажом знаний).\n", + "\n", + "Технически MoE меняет облик FeedForward слоя, добавляя в него экспертов — по сути, кучу тех же FeedForward сетей. Но эти эксперты работают не все сразу. Специальная подпрограмма — роутер — решает, кого из них и когда вызывать. Тем самым мы экономим на инференсе.\n", + "\n", + "В Mixtral 8x7B было 8 экспертов, но в единицу времени работали только 2 из них. Общий вес модели — 47B параметров, но из-за того, что задействовано только 2 эксперта, по скорости инференс получается как у примерно 13B моделей. А качество сравнимо с 70B моделями (например, Llama 2 70B).\n", + "\n", + "В дальнейшем эта новинка станет очень популярна в больших языковых моделях.\n" + ] + }, + { + "cell_type": "markdown", + "id": "e8dd2f09", + "metadata": {}, + "source": [ + "# Mixture-of-Experts\n", + "\n", + "

\n", + " \"MoE\"\n", + "

\n", + "\n", + "Mixture-of-Experts (MoE) — это архитектура нейронных сетей, которая позволяет масштабировать модели до огромных размеров (триллионы параметров) без пропорционального увеличения вычислительных затрат. Достигается это за счет замены FeedForward слоя на слой MoE.\n", + "\n", + "Стандартная полносвязная нейросеть (Feed-Forward Network, FFN) в трансформере обрабатывает каждый входной токен одним и тем же набором весов. В MoE эта единственная сеть заменяется набором «‎экспертов» — небольших независимых нейросетей (обычно тех же FNN). И при проходе для каждого токена система динамически выбирает, каким экспертом или экспертами его обработать. В результате мы получаем два основных выигрыша:\n", + "\n", + "* Снижение нагрузки на систему и, как следствие, ускорение инференса.\n", + "* Повышение когнитивных способностей системы: потому что в каждый момент времени над токеном работают наиболее подходящие для этого эксперты.\n", + "\n", + "> Идея с MoE рано или поздно должна была появиться. На FNN приходится примерно 2/3 всех параметров современных LLM — это очень существенно. И причина этого известна: в FNN хранится основная часть знаний модели. И попытка сэкономить на таком большом количестве параметров является очевидным шагом.\n", + "\n", + "### Алгоритм\n", + "\n", + "MoE состоит из трех основных компонентов: маршрутизатора, экспертов и процедуры взвешивания. Работают они следующим образом:\n", + "\n", + "* На вход поступает тензор `x`.\n", + "* Маршрутизатор (Router или Gating Network) — это небольшая нейросеть. Часто просто один линейный слой. Функция маршрутизатора — решать, каким экспертам обработать `x`. Для этого по каждому токену маршрутизатор выдает: топ-k экспертов и их веса.\n", + "* Эксперты (Experts) — это несколько идентичных по структуре, но разных по весам нейронных сетей. Обычно в их роли выступают стандартные FFN-слои. Эксперты параллельны и не зависят друг от друга. Каждый токен из `x` пропускается только через назначенных ему экспертов — эксперты возвращают результат.\n", + "* Выходной тензор — это взвешенная сумма выходов всех экспертов. Веса для взвешивания мы получаем от маршрутизатора (по каждому токену и каждому эксперту).\n", + "\n", + "> Тут нужно сделать лирическое отступление: хотя сами вычисления происходят заметно быстрее (из-за того что в единицу времени отрабатывает только часть экспертов), но вам все равно потребуется разместить в памяти GPU абсолютно всех экспертов одновременно.\n", + "\n", + "### Экспертность\n", + "\n", + "Технически при обучении LLM эксперты в MoE никак принудительно не «специализируются». Т.е. все эксперты независимы и равны между собой. Но исследования показывают, что эксперты часто самоорганизуются и сами специализируются на определенных типах данных или задачах:\n", + "\n", + "* Эксперт 1 — хорошо работает с техническим текстом.\n", + "* Эксперт 2 — лучше справляется с художественным текстом.\n", + "* Эксперт 3 — специализируется на грамматике.\n", + "* Эксперт 4 — хорошо обрабатывает имена собственные.\n", + "* Эксперт 5 — лучше понимает азиатские языки и т.д.\n", + "\n", + "И встречая очередной токен, роутер, на основе его содержимого, решает, к какому эксперту его отправить.\n", + "\n", + "> При обучении роутер может неравномерно распределять внимание между экспертами. Может даже отправлять все данные к одному эксперту. Лечится это специальными балансными лоссами.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8b5851df", + "metadata": {}, + "source": [ + "**MoE класс (разработка)**" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "982c3595", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "\n", + "class SiLU(nn.Module):\n", + " def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n", + " return torch.sigmoid(x) * x\n", + " \n", + "class SwiGLU(nn.Module):\n", + " def __init__(self, emb_size: int, dropout: float = 0.1):\n", + " super().__init__()\n", + "\n", + " self._gate = nn.Linear(emb_size, 4 * emb_size)\n", + " self._up = nn.Linear(emb_size, 4 * emb_size)\n", + " self._down = nn.Linear(4 * emb_size, emb_size)\n", + " self._activation = SiLU()\n", + " self._dropout = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n", + " gate_out = self._gate(x) # [batch, seq, 4*emb]\n", + " activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n", + " up_out = self._up(x) # [batch, seq, 4*emb]\n", + " out = up_out * activation_out # поэлементное!\n", + " out = self._down(out) # [batch, seq, emb]\n", + " return self._dropout(out)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "34fac9f7", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "\n", + "class MoE(nn.Module):\n", + " def __init__(\n", + " self,\n", + " emb_size: int,\n", + " num_experts: int,\n", + " top_k_experts: int,\n", + " dropout: float = 0.1,\n", + " ):\n", + " super().__init__()\n", + " self._num_experts = num_experts\n", + " self._top_k_experts = top_k_experts\n", + "\n", + " self._router = nn.Linear(emb_size, num_experts)\n", + " self._experts = nn.ModuleList([SwiGLU(\n", + " emb_size=emb_size,\n", + " dropout=dropout,\n", + " ) for _ in range(num_experts)])\n", + " self._dropout = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x: torch.Tensor):\n", + " batch_size, seq_len, emb_size = x.shape\n", + " \n", + " # 1. Пропускаем через роутер\n", + " router_logits = self._router(x) # [batch_size, seq_len, num_experts]\n", + " \n", + " # 2. Отбираем топ-k экспертов для каждого токена\n", + " topk_logits, topk_indices = torch.topk(\n", + " router_logits, \n", + " k=self._top_k_experts, \n", + " dim=-1\n", + " ) # topk_logits: [batch_size, seq_len, top_k]\n", + " # topk_indices: [batch_size, seq_len, top_k]\n", + " \n", + " # 3. Получаем веса через softmax и нормируем\n", + " topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]\n", + " \n", + " # 4. Создаём нулевой тензор для результата\n", + " output = torch.zeros_like(x) # [batch_size, seq_len, emb_size] \n", + "\n", + " # 5. Проходим по всем экспертам\n", + " for expert_id in range(self._num_experts):\n", + " # Шаг 1: Создаём маску - где находится текущий эксперт в топ-k\n", + " expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]\n", + " # Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном\n", + " if not expert_mask.any():\n", + " continue # Эксперт никем не выбран, переходим к следующему\n", + "\n", + " # Шаг 3: Находим токены, которые выбрали этого эксперта\n", + " # (хотя бы в одной из top_k позиций)\n", + " token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]\n", + "\n", + " # Шаг 4: Отбираем токены из x\n", + " # Отбираем токены для этого эксперта\n", + " expert_input = x[token_mask]\n", + "\n", + " # Пропускаем через эксперта\n", + " # Добавляем batch dimension для SwiGLU и затем убираем\n", + " expert_output = self._experts[expert_id](\n", + " expert_input.unsqueeze(0)\n", + " ).squeeze(0)\n", + "\n", + " # Получаем веса для этого эксперта\n", + " # Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)\n", + " # Но на практике каждый эксперт появляется максимум 1 раз в топ-k\n", + " # Находим веса: где expert_mask == True, берём соответствующий вес\n", + " weights_for_expert = torch.zeros(\n", + " batch_size, seq_len, device=x.device\n", + " )\n", + "\n", + " # Для каждой позиции в топ-k\n", + " for k in range(self._top_k_experts):\n", + " mask_k = topk_indices[:, :, k] == expert_id\n", + " weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]\n", + "\n", + " # Отбираем только веса для выбранных токенов\n", + " selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]\n", + "\n", + "\n", + " # Перемножьте выход эксперта на веса текущего эксперта.\n", + " weighted_output = selected_weights.unsqueeze(-1) * expert_output\n", + "\n", + " # Помещаем результат на своё место в выходном тензоре\n", + " output[token_mask] += weighted_output\n", + " \n", + " out = self._dropout(output)\n", + "\n", + " return out" + ] + }, + { + "cell_type": "markdown", + "id": "d9135164", + "metadata": {}, + "source": [] + }, + { + "cell_type": "markdown", + "id": "bef0c904", + "metadata": {}, + "source": [ + "## Full Model" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "35e52050", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torch import Tensor\n", + "import torch.nn.functional as F\n", + "from math import sqrt\n", + "\n", + "\n", + "\n", + "class SiLU(nn.Module):\n", + " def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n", + " return torch.sigmoid(x) * x\n", + " \n", + "class RMSNorm(nn.Module):\n", + " def __init__(self, dim: int, eps: float = 1e-6):\n", + " super().__init__()\n", + " self._eps = eps\n", + " self._w = nn.Parameter(torch.ones(dim))\n", + " \n", + " def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size]\n", + " rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5\n", + " norm_x = x / rms\n", + " return self._w * norm_x\n", + "\n", + "class SwiGLU(nn.Module):\n", + " def __init__(self, emb_size: int, dropout: float = 0.1):\n", + " super().__init__()\n", + "\n", + " self._gate = nn.Linear(emb_size, 4 * emb_size)\n", + " self._up = nn.Linear(emb_size, 4 * emb_size)\n", + " self._down = nn.Linear(4 * emb_size, emb_size)\n", + " self._activation = SiLU()\n", + " self._dropout = nn.Dropout(dropout)\n", + "\n", + " def forward(self, x: torch.Tensor): # [batch_size × seq_len × emb_size].\n", + " gate_out = self._gate(x) # [batch, seq, 4*emb]\n", + " activation_out = self._activation(gate_out) # [batch, seq, 4*emb]\n", + " up_out = self._up(x) # [batch, seq, 4*emb]\n", + " out = up_out * activation_out # поэлементное!\n", + " out = self._down(out) # [batch, seq, emb]\n", + " return self._dropout(out)\n", + "\n", + "\n", + "class TokenEmbeddings(nn.Module):\n", + " def __init__(self, vocab_size: int, emb_size: int):\n", + " super().__init__()\n", + " self._embedding = nn.Embedding(\n", + " num_embeddings=vocab_size,\n", + " embedding_dim=emb_size\n", + " )\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " return self._embedding(x)\n", + "\n", + " @property\n", + " def num_embeddings(self) -> int:\n", + " return self._embedding.num_embeddings\n", + "\n", + " @property\n", + " def embedding_dim(self) -> int:\n", + " return self._embedding.embedding_dim\n", + "\n", + "\n", + "class GELU(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)\n", + " \n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " return 0.5 * x * (1 + torch.tanh(\n", + " self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))\n", + " ))\n", + "\n", + " \n", + " \n", + "import torch\n", + "from torch import nn\n", + "from typing import Optional\n", + "\n", + "\n", + "class RoPE(nn.Module):\n", + "\n", + " def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):\n", + " super().__init__()\n", + " assert head_size % 2 == 0, \"head_size должен быть четным\"\n", + "\n", + " # Вычисление частот: θ_i = base^(-2i/d) для i ∈ [0, d/2-1]\n", + " freqs = 1.0 / (base ** (2 * torch.arange(head_size // 2).float() / head_size))\n", + "\n", + " # Позиции от 0 до max_seq_len-1\n", + " positions = torch.arange(max_seq_len).float()\n", + "\n", + " # Внешнее произведение: m * θ_i для всех позиций и частот\n", + " freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)\n", + "\n", + " # Предвычисление матриц косинусов и синусов\n", + " self.register_buffer(\"cos_matrix\", torch.cos(freq_matrix))\n", + " self.register_buffer(\"sin_matrix\", torch.sin(freq_matrix))\n", + "\n", + " def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor: # [batch_size × seq_len × head_size] [batch_size × num_heads × seq_len × head_size]\n", + " batch_size, num_heads, seq_len, head_size = x.shape\n", + "\n", + " # Берем нужную часть матриц и приводим к типу x\n", + " cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n", + " sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]\n", + "\n", + " # Явное изменение формы для broadcasting\n", + " cos = cos.reshape(1, 1, seq_len, head_size // 2)\n", + " sin = sin.reshape(1, 1, seq_len, head_size // 2)\n", + "\n", + " # Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению\n", + " x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]\n", + " x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]\n", + "\n", + " # Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)\n", + " x_rotated_even = x_even * cos - x_odd * sin\n", + " x_rotated_odd = x_even * sin + x_odd * cos\n", + "\n", + " # Объединяем обратно в исходную размерность\n", + " x_rotated = torch.stack([x_rotated_even, x_rotated_odd], dim=-1)\n", + " x_rotated = x_rotated.flatten(-2) # [batch_size, seq_len, head_size]\n", + "\n", + " return x_rotated\n", + "\n", + "\n", + "import torch\n", + "from torch import nn\n", + "import torch.nn.functional as F\n", + "from typing import Optional, Tuple\n", + "\n", + "\n", + " \n", + "class GroupedQueryAttention(nn.Module):\n", + "\n", + " def __init__(\n", + " self,\n", + " num_q_heads: int,\n", + " num_kv_heads: int,\n", + " emb_size: int,\n", + " head_size: int,\n", + " max_seq_len: int,\n", + " window_size: int,\n", + " rope: RoPE = None,\n", + " dropout: float = 0.1,\n", + " ):\n", + " super().__init__()\n", + " self._num_heads = num_q_heads\n", + " self._num_kv_heads = num_kv_heads\n", + " self._head_size = head_size\n", + " self._max_seq_len = max_seq_len\n", + " self._rope = rope\n", + " self._window_size = window_size\n", + "\n", + " self._q = nn.Linear(emb_size, self._num_heads * head_size)\n", + " self._k = nn.Linear(emb_size, num_kv_heads * head_size)\n", + " self._v = nn.Linear(emb_size, num_kv_heads * head_size)\n", + "\n", + " # Создание causal маски\n", + " mask = self._create_sliding_window_mask(max_seq_len, self._window_size)\n", + " self.register_buffer(\n", + " \"_tril_mask\", mask.bool() if hasattr(torch, \"bool\") else mask.byte()\n", + " )\n", + " \n", + " self._layer = nn.Linear(head_size * self._num_heads, emb_size)\n", + " self._dropout = nn.Dropout(dropout)\n", + "\n", + " def forward(\n", + " self,\n", + " x: torch.Tensor,\n", + " mask: torch.Tensor = None,\n", + " use_cache: bool = True,\n", + " cache: list = None,\n", + " ):\n", + " batch_size, seq_len, emb_size = x.shape\n", + "\n", + " if seq_len > self._max_seq_len:\n", + " raise ValueError(\n", + " f\"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}\"\n", + " )\n", + "\n", + " # Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.\n", + " k = self._k(x) # [B, T, hs]\n", + " q = self._q(x) # [B, T, hs]\n", + " v = self._v(x) # [B, T, hs]\n", + "\n", + " # Шаг 2: Изменение формы для multi-head\n", + " # [batch_size, seq_len, num_heads * head_size] \n", + " # -> [batch_size, seq_len, num_heads, head_size]\n", + " # Измените форму запроса (query) на batch_size × num_q_heads × seq_len × head_size.\n", + " q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)\n", + "\n", + " # Измените форму ключа (key) и значения (value) на batch_size × num_kv_heads × seq_len × head_size.\n", + " k = k.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)\n", + " v = v.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)\n", + " \n", + "\n", + " # 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]\n", + " q = q.transpose(1, 2)\n", + " k = k.transpose(1, 2)\n", + " v = v.transpose(1, 2)\n", + "\n", + " start_pos = 0\n", + " if cache is not None:\n", + " k_cache, v_cache = cache\n", + " cache_len = k_cache.shape[2]\n", + " start_pos = cache_len\n", + " \n", + " # Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.\n", + " if self._rope is not None:\n", + " # Применяем RoPE к Q и K (НЕ к V!)\n", + " q = self._rope(q, start_pos=start_pos) # [B, T, hs]\n", + " k = self._rope(k, start_pos=start_pos) # [B, T, hs]\n", + "\n", + " # Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.\n", + " # 5. Кэширование (для autoregressive generation)\n", + " if cache is not None:\n", + " k_cache, v_cache = cache\n", + " k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)\n", + " v = torch.cat([v_cache, v], dim=2)\n", + "\n", + " # Если use_cache == True, то сохраните матрицы ключа и значения для кэша (это нужно сделать до дублирования голов).\n", + " #if use_cache == True:\n", + " # # Обрезаем до последних window_size токенов\n", + " # k_to_cache = k[:, :, -self._window_size:, :]\n", + " # v_to_cache = v[:, :, -self._window_size:, :]\n", + " # kv_cache = (k_to_cache, v_to_cache)\n", + "\n", + " # Продублируйте головы в тензорах ключа (key) и значения (value), чтобы получился тензор размера на batch_size × num_q_heads × seq_len × head_size.\n", + " #k = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)\n", + " #v = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)\n", + " k_expanded = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)\n", + " v_expanded = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)\n", + " \n", + " # Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.\n", + " # И разделить все значения в матрице внимания на корень из head_size.\n", + " scores = q @ k_expanded.transpose(-2, -1) / (self._head_size ** 0.5)\n", + "\n", + " # 8. Применение маски\n", + " k_seq_len = k_expanded.size(2) # Длина K после concat с кэшем\n", + " \n", + " if cache is None:\n", + " # Случай 1: Без кэша - полная квадратная маска\n", + " # scores: [B, H, seq_len, seq_len]\n", + " # Применяем маску [:seq_len, :seq_len]\n", + " scores = scores.masked_fill(\n", + " ~self._tril_mask[:seq_len, :seq_len], \n", + " float(\"-inf\")\n", + " )\n", + "\n", + " # Применить к матрице внимания (построчно) функцию Softmax.\n", + " weights = F.softmax(scores, dim=-1)\n", + "\n", + " # Перемножим матрицу внимания и матрицу значения.\n", + " x_out = weights @ v_expanded # [B, T, hs]\n", + "\n", + " # Измените форму тензора на batch_size × seq_len × num_heads*head_size.\n", + " # Transpose обратно и concatenate heads\n", + " x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]\n", + " x_out = x_out.contiguous() # Важно для reshape!\n", + " concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)\n", + "\n", + " #concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)\n", + "\n", + " # Пропустите получившийся тензор через последний линейный слой.\n", + " # 3. Проецируем в пространство эмбеддингов\n", + " projected_output = self._layer(concatenated_attention)\n", + "\n", + " # 4. Применяем dropout для регуляризации\n", + " output = self._dropout(projected_output)\n", + "\n", + " if use_cache:\n", + " # Обрезаем оригинальный K и V (до дублирования)\n", + " k_to_cache = k[:, :, -self._window_size:, :]\n", + " v_to_cache = v[:, :, -self._window_size:, :]\n", + " kv_cache = (k_to_cache, v_to_cache)\n", + " return output, kv_cache\n", + " else:\n", + " return output, None\n", + "\n", + " def _repeat_kv_heads(\n", + " self,\n", + " kv: torch.Tensor,\n", + " num_q_heads: int,\n", + " num_kv_heads: int\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " Дублирует головы K/V для соответствия количеству голов Q.\n", + "\n", + " Args:\n", + " kv: [batch_size, num_kv_heads, seq_len, head_size]\n", + " num_q_heads: Количество голов Query (например, 8)\n", + " num_kv_heads: Количество голов Key/Value (например, 2)\n", + "\n", + " Returns:\n", + " [batch_size, num_q_heads, seq_len, head_size]\n", + "\n", + " Example:\n", + " num_q_heads=8, num_kv_heads=2\n", + " Каждая голова KV дублируется 4 раза:\n", + " [KV0, KV1] -> [KV0, KV0, KV0, KV0, KV1, KV1, KV1, KV1]\n", + " \"\"\"\n", + " batch_size, num_kv_heads, seq_len, head_size = kv.shape\n", + "\n", + " if num_q_heads == num_kv_heads:\n", + " # Нет необходимости дублировать\n", + " return kv\n", + "\n", + " # Вычисляем сколько раз нужно повторить каждую голову\n", + " num_repeats = num_q_heads // num_kv_heads\n", + "\n", + " # repeat_interleave дублирует каждую голову num_repeats раз\n", + " # [B, num_kv_heads, S, hs] -> [B, num_q_heads, S, hs]\n", + " # [B, num_kv_heads, S, hs] -> [B, num_kv_heads, 1, S, hs]\n", + " kv = kv.unsqueeze(2)\n", + " \n", + " # [B, num_kv_heads, 1, S, hs] -> [B, num_kv_heads, num_repeats, S, hs]\n", + " kv = kv.repeat(1, 1, num_repeats, 1, 1)\n", + " \n", + " # [B, num_kv_heads, num_repeats, S, hs] -> [B, num_q_heads, S, hs]\n", + " kv = kv.reshape(batch_size, num_q_heads, seq_len, head_size)\n", + " \n", + "\n", + " return kv\n", + "\n", + " def _create_sliding_window_mask(\n", + " self,\n", + " max_seq_len: int,\n", + " window_size: int,\n", + " device: torch.device = None\n", + " ) -> torch.Tensor:\n", + " \"\"\"\n", + " Создает маску для Sliding Window Attention.\n", + "\n", + " Args:\n", + " max_seq_len: Максимальная длина последовательности\n", + " window_size: Размер окна внимания\n", + " device: Устройство для размещения тензора\n", + "\n", + " Returns:\n", + " Маска формы [max_seq_len, max_seq_len], где True = разрешено\n", + "\n", + " Example:\n", + " >>> mask = create_sliding_window_mask(8, 3)\n", + " >>> print(mask.int())\n", + " tensor([[1, 0, 0, 0, 0, 0, 0, 0],\n", + " [1, 1, 0, 0, 0, 0, 0, 0],\n", + " [1, 1, 1, 0, 0, 0, 0, 0],\n", + " [0, 1, 1, 1, 0, 0, 0, 0],\n", + " [0, 0, 1, 1, 1, 0, 0, 0],\n", + " [0, 0, 0, 1, 1, 1, 0, 0],\n", + " [0, 0, 0, 0, 1, 1, 1, 0],\n", + " [0, 0, 0, 0, 0, 1, 1, 1]])\n", + " \"\"\"\n", + " row_indices = torch.arange(max_seq_len, device=device).unsqueeze(1) # [max_seq_len, 1]\n", + " col_indices = torch.arange(max_seq_len, device=device).unsqueeze(0) # [1, max_seq_len]\n", + "\n", + " causal_mask = col_indices <= row_indices\n", + "\n", + " window_mask = (row_indices - col_indices) <= window_size\n", + "\n", + " mask = causal_mask & window_mask\n", + " \n", + " return mask\n", + "\n", + "class Decoder(nn.Module):\n", + " def __init__(self, \n", + " num_q_heads: int,\n", + " num_kv_heads: int,\n", + " emb_size: int,\n", + " head_size: int,\n", + " max_seq_len: int,\n", + " num_experts: int,\n", + " top_k_experts: int,\n", + " window_size: int,\n", + " rope: RoPE,\n", + " dropout: float = 0.1\n", + " ):\n", + " super().__init__()\n", + " self._heads = GroupedQueryAttention(\n", + " num_q_heads=num_q_heads, \n", + " num_kv_heads=num_kv_heads,\n", + " emb_size=emb_size, \n", + " head_size=head_size, \n", + " max_seq_len=max_seq_len,\n", + " window_size=window_size,\n", + " rope=rope,\n", + " dropout=dropout\n", + " )\n", + " self._ff = MoE(\n", + " emb_size=emb_size, \n", + " num_experts=num_experts,\n", + " top_k_experts=top_k_experts,\n", + " dropout=dropout\n", + " )\n", + " self._norm1 = RMSNorm(emb_size)\n", + " self._norm2 = RMSNorm(emb_size)\n", + "\n", + " def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:\n", + " norm1_out = self._norm1(x)\n", + " attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)\n", + " out = attention + x\n", + " \n", + " norm2_out = self._norm2(out)\n", + " ffn_out = self._ff(norm2_out)\n", + "\n", + " if use_cache is True:\n", + " return (ffn_out + out, kv_caches)\n", + " else:\n", + " return (ffn_out + out, None)\n", + "\n", + "\n", + "\n", + "from torch import nn\n", + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "class Mixtral(nn.Module):\n", + " def __init__(self,\n", + " vocab_size: int,\n", + " max_seq_len: int,\n", + " emb_size: int,\n", + " num_q_heads: int,\n", + " num_kv_heads: int,\n", + " head_size: int,\n", + " num_layers: int,\n", + " num_experts: int,\n", + " top_k_experts: int,\n", + " window_size: int,\n", + " dropout: float = 0.1,\n", + " device: str = 'cpu'\n", + " ):\n", + " super().__init__()\n", + " self._vocab_size = vocab_size\n", + " self._max_seq_len = max_seq_len\n", + " self._emb_size = emb_size\n", + " self._num_q_heads = num_q_heads\n", + " self._num_kv_heads = num_kv_heads\n", + " self._head_size = head_size\n", + " self._num_layers = num_layers\n", + " self._dropout = dropout\n", + " self._device = device\n", + " \n", + " self.validation_loss = None\n", + "\n", + " # Инициализация слоев\n", + " self._token_embeddings = TokenEmbeddings(\n", + " vocab_size=vocab_size, \n", + " emb_size=emb_size\n", + " )\n", + " self._position_embeddings = RoPE(\n", + " head_size=head_size,\n", + " max_seq_len=max_seq_len\n", + " )\n", + " #self._position_embeddings = PositionalEmbeddings(\n", + " # max_seq_len=max_seq_len, \n", + " # emb_size=emb_size\n", + " #)\n", + " self._dropout = nn.Dropout(dropout)\n", + " self._decoders = nn.ModuleList([Decoder(\n", + " num_q_heads=num_q_heads,\n", + " num_kv_heads=num_kv_heads,\n", + " emb_size=emb_size,\n", + " head_size=head_size,\n", + " max_seq_len=max_seq_len,\n", + " num_experts=num_experts,\n", + " top_k_experts=top_k_experts,\n", + " window_size=window_size,\n", + " rope=self._position_embeddings,\n", + " dropout=dropout \n", + " ) for _ in range(num_layers)])\n", + " self._norm = RMSNorm(emb_size)\n", + " self._linear = nn.Linear(emb_size, vocab_size)\n", + "\n", + " def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:\n", + " # Проверка длины последовательности (только при отсутствии кэша)\n", + " if cache is None and x.size(1) > self._max_seq_len:\n", + " raise ValueError(f\"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}\")\n", + " \n", + " # Эмбеддинги токенов и позиций\n", + " tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]\n", + " #pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]\n", + " \n", + " # Комбинирование\n", + " out = self._dropout(tok_out) # [batch, seq_len, emb_size]\n", + " \n", + " # Стек декодеров с передачей кэша\n", + " new_cache = []\n", + " for i, decoder in enumerate(self._decoders):\n", + " decoder_cache = cache[i] if cache is not None else None\n", + " decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)\n", + "\n", + " # Извлекаем результат из кортежа\n", + " if use_cache:\n", + " out, decoder_new_cache = decoder_result\n", + " new_cache.append(decoder_new_cache)\n", + " else:\n", + " out = decoder_result[0]\n", + "\n", + " out = self._norm(out)\n", + " logits = self._linear(out)\n", + " \n", + " # Возвращаем результат с учетом use_cache\n", + " if use_cache:\n", + " return (logits, new_cache)\n", + " else:\n", + " return (logits, None)\n", + "\n", + " def generate(self,\n", + " x: torch.Tensor, \n", + " max_new_tokens: int, \n", + " do_sample: bool,\n", + " temperature: float = 1.0,\n", + " top_k: int = None,\n", + " top_p: float = None,\n", + " use_cache: bool = True\n", + " ) -> torch.Tensor:\n", + " cache = None\n", + "\n", + " for _ in range(max_new_tokens):\n", + " if use_cache and cache is not None:\n", + " # Используем кэш - передаем только последний токен\n", + " x_input = x[:, -1:] # [batch_size, 1]\n", + " else:\n", + " # Первая итерация или кэш отключен - передаем всю последовательность\n", + " x_input = x\n", + " \n", + " # Прямой проход с кэшем\n", + " logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)\n", + " \n", + " # Обновляем кэш для следующей итерации\n", + " if use_cache:\n", + " cache = new_cache\n", + "\n", + " last_logits = logits[:, -1, :] # [batch_size, vocab_size]\n", + "\n", + " # Масштабируем логиты температурой\n", + " if temperature > 0:\n", + " logits_scaled = last_logits / temperature\n", + " else:\n", + " logits_scaled = last_logits\n", + "\n", + " if do_sample == True and top_k != None:\n", + " _, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)\n", + "\n", + " # # Заменим все НЕ top-k логиты на -inf\n", + " masked_logits = logits_scaled.clone()\n", + " vocab_size = logits_scaled.size(-1)\n", + "\n", + " # создаём маску: 1, если токен НЕ в topk_indices\n", + " mask = torch.ones_like(logits_scaled, dtype=torch.uint8)\n", + " mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы\n", + " masked_logits[mask.byte()] = float('-inf')\n", + "\n", + " logits_scaled = masked_logits\n", + "\n", + " if do_sample == True and top_p != None:\n", + " # 1. Применим softmax, чтобы получить вероятности:\n", + " probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]\n", + " # 2. Отсортируем токены по убыванию вероятностей:\n", + " sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)\n", + " # 3. Посчитаем кумулятивную сумму вероятностей:\n", + " cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]\n", + " # 4. Определим маску: оставить токены, пока сумма < top_p\n", + " sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]\n", + " # Гарантируем, что хотя бы первый токен останется\n", + " sorted_mask[:, 0] = 1\n", + " # 5. Преобразуем маску обратно в оригинальный порядок:\n", + " # Создаём полную маску из 0\n", + " mask = torch.zeros_like(probs, dtype=torch.uint8)\n", + " # Устанавливаем 1 в местах нужных токенов\n", + " mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)\n", + " # 6. Зануляем логиты токенов вне топ-p:\n", + " logits_scaled[~mask] = float('-inf')\n", + "\n", + " # 4. Применяем Softmax\n", + " probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]\n", + "\n", + "\n", + " if do_sample == True:\n", + " # 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial\n", + " next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]\n", + " else:\n", + " # 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью\n", + " next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]\n", + " \n", + " # 6. Добавляем его к последовательности\n", + " x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]\n", + " return x\n", + "\n", + " def save(self, path):\n", + " torch.save({\n", + " 'model_state_dict': self.state_dict(),\n", + " 'vocab_size': self._vocab_size,\n", + " 'max_seq_len': self._max_seq_len,\n", + " 'emb_size': self._emb_size,\n", + " 'num_heads': self._num_heads,\n", + " 'head_size': self._head_size,\n", + " 'num_layers': self._num_layers\n", + " }, path)\n", + "\n", + " @classmethod\n", + " def load(cls, path, device):\n", + " checkpoint = torch.load(path, map_location=device)\n", + " model = cls(\n", + " vocab_size=checkpoint['vocab_size'],\n", + " max_seq_len=checkpoint['max_seq_len'],\n", + " emb_size=checkpoint['emb_size'],\n", + " num_heads=checkpoint['num_heads'],\n", + " head_size=checkpoint['head_size'],\n", + " num_layers=checkpoint['num_layers']\n", + " )\n", + " model.load_state_dict(checkpoint['model_state_dict'])\n", + " model.to(device)\n", + " return model\n", + "\n", + " @property\n", + " def max_seq_len(self) -> int:\n", + " return self._max_seq_len" + ] + }, + { + "cell_type": "markdown", + "id": "7f4b3b1e", + "metadata": {}, + "source": [ + "## 2. Обучение Mixtral\n", + "\n", + "Mixtral обучается в два этапа:\n", + "\n", + "- 1️⃣ **Предобучение (Unsupervised Pretraining)** \n", + "- 2️⃣ **Дообучение (Supervised Fine-Tuning)**" + ] + }, + { + "cell_type": "markdown", + "id": "cb0a9172", + "metadata": {}, + "source": [ + "\n", + "\n", + "### 5.1 Предобучение\n", + "\n", + "На первом этапе модель обучается без разметки: она получает большой корпус текстов и учится **предсказывать следующий токен** по предыдущим.\n", + "\n", + "Функция потерь:\n", + "$$\n", + "L = - \\sum_{t=1}^{T} \\log P(x_t | x_1, x_2, ..., x_{t-1})\n", + "$$\n", + "\n", + "Таким образом, модель учится строить вероятностную модель языка, \"угадывая\" продолжение текста.\n" + ] + }, + { + "cell_type": "markdown", + "id": "b4bdbb31", + "metadata": {}, + "source": [ + "Во время **предобучения** Mistral учится **предсказывать следующий токен** (language modeling task). \n", + "Формально: \n", + "$$ \n", + "P(x_t ,|, x_1, x_2, \\dots, x_{t-1}) \n", + "$$ \n", + "То есть, если на вход подаётся предложение `\"I love deep\"`, модель должна предсказать `\"learning\"`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "379a5c21", + "metadata": {}, + "source": [ + "### ✅ 5.1.1 Подготовка данных\n", + "\n", + "Создадим **датасет** на основе BPE-токенизатора:" + ] + }, + { + "cell_type": "markdown", + "id": "8141ac57", + "metadata": {}, + "source": [ + "**BPE Tokenizator**" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3fc41e15", + "metadata": {}, + "outputs": [], + "source": [ + "class BPE:\n", + " def __init__(self, vocab_size: int):\n", + " self.vocab_size = vocab_size\n", + " self.id2token = {}\n", + " self.token2id = {}\n", + "\n", + " def fit(self, text: str):\n", + " # 1. Получаем уникальные токены (символы)\n", + " unique_tokens = sorted(set(text))\n", + " tokens = unique_tokens.copy()\n", + "\n", + " # 2. Разбиваем текст на токены-символы\n", + " sequence = list(text)\n", + "\n", + " # 3. Объединяем токены до достижения нужного размера словаря\n", + " while len(tokens) < self.vocab_size:\n", + " #print(f'len={len(tokens)} < {self.vocab_size}')\n", + " # Считаем частоты пар\n", + " pair_freq = {}\n", + " for i in range(len(sequence) - 1):\n", + " pair = (sequence[i], sequence[i + 1])\n", + " #print(f'pair = {pair}')\n", + " if pair not in pair_freq:\n", + " pair_freq[pair] = 0\n", + " pair_freq[pair] += 1\n", + "\n", + "\n", + " #print(f'pair_freq = {pair_freq}') \n", + " if not pair_freq:\n", + " break # нет пар — выходим\n", + "\n", + " #for x in pair_freq.items():\n", + " # self.debug(x, sequence)\n", + "\n", + " # Находим самую частую пару (в случае равенства — та, что встретилась первой)\n", + " most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]\n", + " #print(most_frequent_pair)\n", + " # Создаем новый токен\n", + " new_token = most_frequent_pair[0] + most_frequent_pair[1]\n", + " #print(f\"new token={new_token}\")\n", + " tokens.append(new_token)\n", + " #print(f\"tokens={tokens}\")\n", + "\n", + " i = 0\n", + " new_sequence = []\n", + "\n", + " while i < len(sequence):\n", + " if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:\n", + " new_sequence.append(new_token)\n", + " i += 2 # пропускаем два символа — заменённую пару\n", + " else:\n", + " new_sequence.append(sequence[i])\n", + " i += 1\n", + " sequence = new_sequence\n", + " #break\n", + " \n", + " # 4. Создаем словари\n", + " self.vocab = tokens.copy()\n", + " self.token2id = dict(zip(tokens, range(self.vocab_size)))\n", + " self.id2token = dict(zip(range(self.vocab_size), tokens))\n", + "\n", + " def _pair_first_index(self, sequence, pair):\n", + " for i in range(len(sequence) - 1):\n", + " if (sequence[i], sequence[i + 1]) == pair:\n", + " return i\n", + " return float('inf') # если пара не найдена (в теории не должно случиться)\n", + "\n", + "\n", + " def encode(self, text: str):\n", + " # 1. Разбиваем текст на токены-символы\n", + " sequence = list(text)\n", + " # 2. Инициализация пустого списка токенов\n", + " tokens = []\n", + " # 3. Установить i = 0\n", + " i = 0\n", + " while i < len(text):\n", + " # 3.1 Найти все токены в словаре, начинающиеся с text[i]\n", + " start_char = text[i]\n", + " result = [token for token in self.vocab if token.startswith(start_char)]\n", + " # 3.2 Выбрать самый длинный подходящий токен\n", + " find_token = self._find_max_matching_token(text[i:], result)\n", + " if find_token is None:\n", + " # Обработка неизвестного символа\n", + " tokens.append(text[i]) # Добавляем сам символ как токен\n", + " i += 1\n", + " else:\n", + " # 3.3 Добавить токен в результат\n", + " tokens.append(find_token)\n", + " # 3.4 Увеличить i на длину токена\n", + " i += len(find_token)\n", + "\n", + " # 4. Заменить токены на их ID\n", + " return self._tokens_to_ids(tokens)\n", + "\n", + " def _find_max_matching_token(self, text: str, tokens: list):\n", + " \"\"\"Находит самый длинный токен из списка, с которого начинается текст\"\"\"\n", + " matching = [token for token in tokens if text.startswith(token)]\n", + " return max(matching, key=len) if matching else None\n", + "\n", + " def _tokens_to_ids(self, tokens):\n", + " \"\"\"Конвертирует список токенов в их ID с обработкой неизвестных токенов\"\"\"\n", + " ids = []\n", + " for token in tokens:\n", + " if token in self.token2id:\n", + " ids.append(self.token2id[token])\n", + " else:\n", + " ids.append(0) # Специальное значение\n", + " return ids\n", + "\n", + "\n", + " def decode(self, ids: list) -> str:\n", + " return ''.join(self._ids_to_tokens(ids))\n", + "\n", + " def _ids_to_tokens(self, ids: list) -> list:\n", + " \"\"\"Конвертирует список Ids в их tokens\"\"\"\n", + " tokens = []\n", + " for id in ids:\n", + " if id in self.id2token:\n", + " tokens.append(self.id2token[id])\n", + " else:\n", + " tokens.append('') # Специальное значение\n", + " return tokens\n", + "\n", + "\n", + " def save(self, filename):\n", + " with open(filename, 'wb') as f:\n", + " dill.dump(self, f)\n", + " print(f\"Объект сохранён в {filename}\")\n", + "\n", + "\n", + " @classmethod\n", + " def load(cls, filename):\n", + " with open(filename, 'rb') as f:\n", + " obj = dill.load(f)\n", + " \n", + " print(f\"Объект загружен из {filename}\")\n", + " return obj" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "dc7a3dbf", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch.utils.data import Dataset, DataLoader\n", + "\n", + "class GPTDataset(Dataset):\n", + " def __init__(self, text: str, bpe: BPE, block_size: int):\n", + " self.bpe = bpe\n", + " self.block_size = block_size\n", + " self.data = bpe.encode(text)\n", + " \n", + " def __len__(self):\n", + " return len(self.data) - self.block_size\n", + "\n", + " def __getitem__(self, idx):\n", + " x = torch.tensor(self.data[idx:idx+self.block_size], dtype=torch.long)\n", + " y = torch.tensor(self.data[idx+1:idx+self.block_size+1], dtype=torch.long)\n", + " return x, y" + ] + }, + { + "cell_type": "markdown", + "id": "e8981229", + "metadata": {}, + "source": [ + "### ✅ 5.1.2 Цикл обучения\n", + "\n", + "Для обучения создадим функцию:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "4e097434", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn.functional as F\n", + "from torch import optim\n", + "\n", + "def train_mixtral(model, dataset, epochs=5, batch_size=32, lr=3e-4, device='cpu'):\n", + " dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", + " optimizer = optim.AdamW(model.parameters(), lr=lr)\n", + "\n", + " model.to(device)\n", + " model.train()\n", + "\n", + " for epoch in range(epochs):\n", + " total_loss = 0\n", + " for x, y in dataloader:\n", + " x, y = x.to(device), y.to(device)\n", + "\n", + " # Прямой проход\n", + " logits, _ = model(x, use_cache=False) # [B, T, vocab_size]\n", + "\n", + " # Перестроим выход под CrossEntropy\n", + " loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n", + "\n", + " # Обратное распространение\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + "\n", + " total_loss += loss.item()\n", + "\n", + " avg_loss = total_loss / len(dataloader)\n", + " print(f\"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}\")\n", + "\n", + " return model" + ] + }, + { + "cell_type": "markdown", + "id": "1f536de2", + "metadata": {}, + "source": [ + "### ✅ 5.1.3 Пример запуска\n", + "\n", + "\n", + "**🧠 Конфигурация Mistral Mini**\n", + "\n", + "\n", + "| Параметр | Значение | Описание |\n", + "| --------------- | -------- | --------------------------------------------- |\n", + "| **vocab_size** | `50257` | Размер словаря (BPE токенизатор OpenAI) |\n", + "| **max_seq_len** | `512` | Максимальная длина входной последовательности |\n", + "| **emb_size** | `256` | Размер эмбеддингов (векторное пространство) |\n", + "| **num_heads** | `4` | Количество голов в multi-head attention |\n", + "| **head_size** | `64` | Размерность одной головы внимания (768 / 12) |\n", + "| **num_layers** | `4` | Количество блоков (декодеров) |\n", + "| **dropout** | `0.1` | Вероятность дропаута |\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "83ed4be7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Dataset length: 20\n", + "Epoch 1/100, Loss: 3.5861\n", + "Epoch 2/100, Loss: 1.1913\n", + "Epoch 3/100, Loss: 0.4864\n", + "Epoch 4/100, Loss: 0.2490\n", + "Epoch 5/100, Loss: 0.1993\n", + "Epoch 6/100, Loss: 0.1350\n", + "Epoch 7/100, Loss: 0.1039\n", + "Epoch 8/100, Loss: 0.1051\n", + "Epoch 9/100, Loss: 0.0730\n", + "Epoch 10/100, Loss: 0.0754\n", + "Epoch 11/100, Loss: 0.0819\n", + "Epoch 12/100, Loss: 0.0664\n", + "Epoch 13/100, Loss: 0.0793\n", + "Epoch 14/100, Loss: 0.0668\n", + "Epoch 15/100, Loss: 0.0818\n", + "Epoch 16/100, Loss: 0.0734\n", + "Epoch 17/100, Loss: 0.0637\n", + "Epoch 18/100, Loss: 0.0584\n", + "Epoch 19/100, Loss: 0.0762\n", + "Epoch 20/100, Loss: 0.0683\n", + "Epoch 21/100, Loss: 0.0624\n", + "Epoch 22/100, Loss: 0.0557\n", + "Epoch 23/100, Loss: 0.0579\n", + "Epoch 24/100, Loss: 0.0558\n", + "Epoch 25/100, Loss: 0.0578\n", + "Epoch 26/100, Loss: 0.0520\n", + "Epoch 27/100, Loss: 0.0642\n", + "Epoch 28/100, Loss: 0.0660\n", + "Epoch 29/100, Loss: 0.0508\n", + "Epoch 30/100, Loss: 0.0542\n", + "Epoch 31/100, Loss: 0.0478\n", + "Epoch 32/100, Loss: 0.0597\n", + "Epoch 33/100, Loss: 0.0578\n", + "Epoch 34/100, Loss: 0.0557\n", + "Epoch 35/100, Loss: 0.0528\n", + "Epoch 36/100, Loss: 0.0521\n", + "Epoch 37/100, Loss: 0.0520\n", + "Epoch 38/100, Loss: 0.0512\n", + "Epoch 39/100, Loss: 0.0501\n", + "Epoch 40/100, Loss: 0.0497\n", + "Epoch 41/100, Loss: 0.0454\n", + "Epoch 42/100, Loss: 0.0519\n", + "Epoch 43/100, Loss: 0.0535\n", + "Epoch 44/100, Loss: 0.0464\n", + "Epoch 45/100, Loss: 0.0459\n", + "Epoch 46/100, Loss: 0.0437\n", + "Epoch 47/100, Loss: 0.0585\n", + "Epoch 48/100, Loss: 0.0469\n", + "Epoch 49/100, Loss: 0.0538\n", + "Epoch 50/100, Loss: 0.0592\n", + "Epoch 51/100, Loss: 0.0520\n", + "Epoch 52/100, Loss: 0.0582\n", + "Epoch 53/100, Loss: 0.0504\n", + "Epoch 54/100, Loss: 0.0471\n", + "Epoch 55/100, Loss: 0.0478\n", + "Epoch 56/100, Loss: 0.0487\n", + "Epoch 57/100, Loss: 0.0507\n", + "Epoch 58/100, Loss: 0.0500\n", + "Epoch 59/100, Loss: 0.0457\n", + "Epoch 60/100, Loss: 0.0493\n", + "Epoch 61/100, Loss: 0.0431\n", + "Epoch 62/100, Loss: 0.0503\n", + "Epoch 63/100, Loss: 0.0436\n", + "Epoch 64/100, Loss: 0.0512\n", + "Epoch 65/100, Loss: 0.0488\n", + "Epoch 66/100, Loss: 0.0436\n", + "Epoch 67/100, Loss: 0.0505\n", + "Epoch 68/100, Loss: 0.0389\n", + "Epoch 69/100, Loss: 0.0447\n", + "Epoch 70/100, Loss: 0.0442\n", + "Epoch 71/100, Loss: 0.0443\n", + "Epoch 72/100, Loss: 0.0460\n", + "Epoch 73/100, Loss: 0.0483\n", + "Epoch 74/100, Loss: 0.0538\n", + "Epoch 75/100, Loss: 0.0437\n", + "Epoch 76/100, Loss: 0.0500\n", + "Epoch 77/100, Loss: 0.0443\n", + "Epoch 78/100, Loss: 0.0461\n", + "Epoch 79/100, Loss: 0.0467\n", + "Epoch 80/100, Loss: 0.0472\n", + "Epoch 81/100, Loss: 0.0507\n", + "Epoch 82/100, Loss: 0.0454\n", + "Epoch 83/100, Loss: 0.0414\n", + "Epoch 84/100, Loss: 0.0468\n", + "Epoch 85/100, Loss: 0.0525\n", + "Epoch 86/100, Loss: 0.0456\n", + "Epoch 87/100, Loss: 0.0512\n", + "Epoch 88/100, Loss: 0.0487\n", + "Epoch 89/100, Loss: 0.0452\n", + "Epoch 90/100, Loss: 0.0436\n", + "Epoch 91/100, Loss: 0.0488\n", + "Epoch 92/100, Loss: 0.0471\n", + "Epoch 93/100, Loss: 0.0442\n", + "Epoch 94/100, Loss: 0.0435\n", + "Epoch 95/100, Loss: 0.0502\n", + "Epoch 96/100, Loss: 0.0430\n", + "Epoch 97/100, Loss: 0.0430\n", + "Epoch 98/100, Loss: 0.0424\n", + "Epoch 99/100, Loss: 0.0421\n", + "Epoch 100/100, Loss: 0.0410\n" + ] + }, + { + "data": { + "text/plain": [ + "Mixtral(\n", + " (_token_embeddings): TokenEmbeddings(\n", + " (_embedding): Embedding(100, 256)\n", + " )\n", + " (_position_embeddings): RoPE()\n", + " (_dropout): Dropout(p=0.1, inplace=False)\n", + " (_decoders): ModuleList(\n", + " (0-3): 4 x Decoder(\n", + " (_heads): GroupedQueryAttention(\n", + " (_rope): RoPE()\n", + " (_q): Linear(in_features=256, out_features=256, bias=True)\n", + " (_k): Linear(in_features=256, out_features=128, bias=True)\n", + " (_v): Linear(in_features=256, out_features=128, bias=True)\n", + " (_layer): Linear(in_features=256, out_features=256, bias=True)\n", + " (_dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (_ff): MoE(\n", + " (_router): Linear(in_features=256, out_features=8, bias=True)\n", + " (_experts): ModuleList(\n", + " (0-7): 8 x SwiGLU(\n", + " (_gate): Linear(in_features=256, out_features=1024, bias=True)\n", + " (_up): Linear(in_features=256, out_features=1024, bias=True)\n", + " (_down): Linear(in_features=1024, out_features=256, bias=True)\n", + " (_activation): SiLU()\n", + " (_dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (_dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (_norm1): RMSNorm()\n", + " (_norm2): RMSNorm()\n", + " )\n", + " )\n", + " (_norm): RMSNorm()\n", + " (_linear): Linear(in_features=256, out_features=100, bias=True)\n", + ")" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# 1. Исходный текст\n", + "text = \"Deep learning is amazing. Transformers changed the world. Attention is all you need. GPT models revolutionized NLP.\"\n", + "\n", + "# 2. Обучаем токенизатор\n", + "bpe = BPE(vocab_size=100)\n", + "bpe.fit(text)\n", + "\n", + "# 3. Создаем датасет\n", + "dataset = GPTDataset(text, bpe, block_size=8)\n", + "print(f\"Dataset length: {len(dataset)}\")\n", + "\n", + "# 4. Инициализируем модель\n", + "model = Mixtral(\n", + " vocab_size=len(bpe.vocab), # размер словаря BPE\n", + " max_seq_len=512, # GPT-2 использует контекст в 512 токена\n", + " emb_size=256, # размер эмбеддингов\n", + " num_q_heads=4, # количество голов внимания\n", + " num_kv_heads=2, # количество голов внимания\n", + " head_size=64, # размер каждой головы (256 / 4)\n", + " num_layers=4, # количество блоков Transformer\n", + " num_experts=8,\n", + " top_k_experts=2,\n", + " window_size=8,\n", + " dropout=0.1 # стандартный dropout GPT-2\n", + ")\n", + "\n", + "# 5. Обучаем\n", + "train_mixtral(model, dataset, epochs=100, batch_size=4)" + ] + }, + { + "cell_type": "markdown", + "id": "72e91ceb", + "metadata": {}, + "source": [ + "\n", + "---\n", + "\n", + "### 5.2 Дообучение\n", + "\n", + "После предобучения LLAMA уже знает структуру и грамматику языка. \n", + "На втором этапе она дообучается на конкретных задачах (например, классификация, QA) с помощью размеченных данных.\n", + "\n", + "Технически это почти то же обучение, только:\n", + "\n", + "- Загружаем модель с уже обученными весами.\n", + "- Используем новые данные.\n", + "- Можно уменьшить скорость обучения.\n", + "- Иногда замораживают часть слоёв (например, эмбеддинги).\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "0a6f9730", + "metadata": {}, + "outputs": [], + "source": [ + "def fine_tune_mixtral(model, dataset, epochs=3, batch_size=16, lr=1e-5, device='cpu', freeze_embeddings=True):\n", + " if freeze_embeddings:\n", + " for param in model._token_embeddings.parameters():\n", + " param.requires_grad = False\n", + " for param in model._position_embeddings.parameters():\n", + " param.requires_grad = False\n", + "\n", + " dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)\n", + " optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)\n", + "\n", + " model.to(device)\n", + " model.train()\n", + "\n", + " for epoch in range(epochs):\n", + " total_loss = 0\n", + " for x, y in dataloader:\n", + " x, y = x.to(device), y.to(device)\n", + " logits, _ = model(x, use_cache=False)\n", + " loss = F.cross_entropy(logits.view(-1, logits.size(-1)), y.view(-1))\n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " total_loss += loss.item()\n", + " print(f\"Fine-tune Epoch {epoch+1}/{epochs}, Loss: {total_loss / len(dataloader):.4f}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "c9e28f47", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Fine-tune Epoch 1/10, Loss: 4.8596\n", + "Fine-tune Epoch 2/10, Loss: 2.6883\n", + "Fine-tune Epoch 3/10, Loss: 1.5315\n", + "Fine-tune Epoch 4/10, Loss: 1.1258\n", + "Fine-tune Epoch 5/10, Loss: 0.9248\n", + "Fine-tune Epoch 6/10, Loss: 0.7725\n", + "Fine-tune Epoch 7/10, Loss: 0.6405\n", + "Fine-tune Epoch 8/10, Loss: 0.5367\n", + "Fine-tune Epoch 9/10, Loss: 0.4469\n", + "Fine-tune Epoch 10/10, Loss: 0.4168\n" + ] + } + ], + "source": [ + "# Например, мы хотим дообучить модель на стиле коротких технических фраз\n", + "fine_tune_text = \"\"\"\n", + "Transformers revolutionize NLP.\n", + "Deep learning enables self-attention.\n", + "GPT generates text autoregressively.\n", + "\"\"\"\n", + "\n", + "dataset = GPTDataset(fine_tune_text, bpe, block_size=8)\n", + "\n", + "\n", + "# Запуск дообучения\n", + "fine_tune_mixtral(model, dataset, epochs=10, batch_size=4, lr=1e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "fd470fdd", + "metadata": {}, + "source": [ + "## 📝 6. Генерация текста после обучения" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "245dd064", + "metadata": {}, + "outputs": [], + "source": [ + "def generate_text(model, bpe, prompt: str, max_new_tokens=20, device='cpu'):\n", + " model.eval()\n", + " ids = torch.tensor([bpe.encode(prompt)], dtype=torch.long).to(device)\n", + " out = model.generate(ids, max_new_tokens=max_new_tokens, do_sample=True)\n", + " text = bpe.decode(out[0].tolist())\n", + " return text" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "1f5db85f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Deep learning ena les self atelf a\n" + ] + } + ], + "source": [ + "print(generate_text(model, bpe, \"Deep learning\", max_new_tokens=20))" + ] + }, + { + "cell_type": "markdown", + "id": "094ab394", + "metadata": {}, + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}