44 Commits

Author SHA1 Message Date
Sergey Penkovsky
7744658716 Merge pull request #6 from pese-git/ref/gpt1
Ref/gpt1
2025-10-31 09:15:54 +03:00
Sergey Penkovsky
21cfd79c19 refactor(assets): update and reorganize GPT-1 architecture diagrams
- Renamed GPT-1 main scheme files for clarity
- Added new diagram files for attention, decoder, embeddings, and forward blocks (both .drawio and .png)
- Removed deprecated files (gpt11.drawio, gpt1.svg)
- Updated notebooks/gpt.ipynb with relevant changes
2025-10-30 14:40:31 +03:00
Sergey Penkovsky
9e2796e6be docs(gpt1): add architecture diagrams and notebook updates
- Added architecture diagrams for GPT-1: gpt1.drawio, gpt11.drawio (drawio format)
- Exported visualization images: gpt1.png, gpt1.svg for documentation and presentations
- Updated gpt.ipynb notebook to reference new materials and possibly add explanations of layers/logic
- New assets help to clarify model structure and training flow for both contributors and external users
2025-10-24 17:42:11 +03:00
Sergey Penkovsky
25caf69ced refactor(gpt1): migrate Decoder to GptDecoder, unify API, and update tests
- Renamed Decoder (and decoder.py) to GptDecoder (gpt_decoder.py) for clarity in GPT1
- Implemented support for cache and use_cache parameters in GptDecoder.forward (API unification)
- Adapted all usages in GPT model to use new decoder structure and handle tuple output
- Refactored core tests (test_gpt.py, test_gpt_decoder.py, test_basic.py) to correctly expect tuple or logits and ensure shape/device checks work as before
- Improved clarity and future extensibility for autoregressive generation and benchmarking
- No changes to architectural details or training loop; pure API and test modernization
2025-10-22 16:27:08 +03:00
Sergey Penkovsky
ddc4924a37 refactor(models): unify generate() signatures across all LLM architectures\n\n- Unified method signature: (x, max_new_tokens, do_sample, temperature, top_k, top_p, use_cache, attention_mask, **kwargs)\n- Added del attention_mask, kwargs in every generate() for compatibility and clean API\n- Prepared for drop-in replacement and ease of future batching/serving\n\nNo changes to core model logic or sampling algorithms. 2025-10-22 11:57:26 +03:00
Sergey Penkovsky
92a34551b8 Merge pull request #5 from pese-git/feature/gemma
Feature/gemma
2025-10-21 17:53:55 +03:00
Sergey Penkovsky
ea932a36f3 feat(gemma): document and test GeGLU, MultiQueryAttention, GemmaDecoder, update Gemma model docs
- Add new core modules: GeGLU (Gated GELU Linear Unit), GemmaDecoder, MultiQueryAttention; all with highly detailed scientific (RU) docstrings: theory, usage, formulas, references
- Major doc improvements in Gemma model: class, __init__, forward, generate now have full educational/engineering docstrings, use-case samples, and literature links
- Add comprehensive unit tests:
    * tests/core/test_geglu.py: GeGLU coverage (shape, grads, edge, repeat, float16/skip)
    * tests/core/test_gemma_decoder.py: GemmaDecoder coverage (shape, mask, cache, repeatability, errors)
    * tests/core/test_multi_query_attention.py: MQA coverage (shape, cache, gradients, masking, dropout, raise)
- All modules and tests follow strict quality/documentation standards, code is now robust for research & production
2025-10-21 15:12:45 +03:00
Sergey Penkovsky
cfb4b6dfb1 feat(gemma): initial implementation of Gemma model and configs
- Add core Gemma model (architecture, attention, GeGLU, RoPE, RMSNorm, etc)
- Add configs for training and generation: gemma_train.json, gemma_generate.json
- Add Gemma notebook for exploratory analysis and demonstration
- Add __init__.py for Gemma submodule
- Update run_llm_experiment.py to support Gemma experiment configs

test(gemma): add comprehensive unit tests for Gemma

- Test forward pass (with/without cache)
- Test autoregressive generation (greedy, top-k, top-p)
- Test shape correctness and max sequence length errors
- Test multi-layer stack and token embeddings

docs: add documentation notebook for Gemma usage and analysis

Closes: #issue (if applicable)
2025-10-21 01:02:15 +03:00
Sergey Penkovsky
58c4a00b48 Merge pull request #4 from pese-git/feature/mixtral
Feature/mixtral
2025-10-20 16:36:39 +03:00
Sergey Penkovsky
c9da4c841b feat(mixtral): add MixtralDecoder, enhance MoE and Mixtral model docs, add unit tests
- Implement new core module: MixtralDecoder (llm/core/mixtral_decoder.py) with full Russian scientific docstrings, formal math, and usage examples
- Improve MoE: add Russian docstrings for class, __init__, forward; validate top_k_experts; explain theory and components
- Refactor Mixtral model: switch stack to MixtralDecoder, add comprehensive documentation for class, constructor and forward, clarify config usage and architecture
- Add thorough unit tests:
   * tests/core/test_mixtral_decoder.py: checks shapes, errors, mask, dropout, grads etc.
   * tests/core/test_moe.py: covers normal and edge-case logic, gradients, shape, params check
- All code and tests in compliance with recent scientific and engineering standards.
2025-10-20 16:07:51 +03:00
Sergey Penkovsky
b1737bbce2 feat(mixtral): initial implementation of Mixtral MoE model, configs, and tests
- Add Mixtral architecture implementation with MoE support (llm/src/llm/models/mixtral/mixtral.py)
- Introduce generic Mixture-of-Experts (MoE) block (llm/src/llm/core/moe.py)
- Create dedicated configuration files for Mixtral training and generation experiments
- Register and test Mixtral support in experiment runner (run_llm_experiment.py)
- Add unit tests for Mixtral API including forward, caching, and generation modes
- Include Jupyter notebook mixstral.ipynb for architectural exploration and research
- Ensure correct handling of torch bool masks in sampling (top-k, top-p) during generation

BREAKING CHANGE: Adds new model code and test coverage, modifying experiment runner logic to register Mixtral.
2025-10-20 08:12:11 +03:00
Sergey Penkovsky
1aba02cab9 Merge pull request #3 from pese-git/feature/mistral
Feature/mistral
2025-10-17 20:45:20 +03:00
Sergey Penkovsky
9794db3e18 docs(readme): update project documentation for LLaMA, Mistral, HF integration
- Added explicit support and usage examples for Mistral and LLaMA architectures in both root and llm/ READMEs
- Updated directory structure and naming (datasets, tokenizers, mistral, hf-proxy)
- Clarified quickstart and experiments usage including config location and CLI
- Documented HuggingFace integration via  and marked it as experimental
- Highlighted differences and specifics of all supported architectures
- Improved guide for launching training/generation/experiments
- Made project scope and architecture more transparent for new contributors
2025-10-17 20:18:57 +03:00
Sergey Penkovsky
d947b7beb3 update and expand scientific docstrings for optimizer, scheduler, trainer
- Expanded module-level and function/class docstrings in optimizer.py, scheduler.py, and trainer.py
- Described mathematical foundations, theoretical motivations, and provided detailed usage examples for students
- All docstrings in Russian, clear scientific style

test(training): add comprehensive tests for optimizer, scheduler, and trainer modules

- Added new test files for get_optimizer, get_linear_schedule_with_warmup, and Trainer
- Tests cover parameter handling, edge cases, and expected learning dynamics (lr schedules and loss behavior)
- Trainer now logs average epoch losses to self.loss_history for testability and analysis

refactor(training/trainer): log epoch loss to loss_history for downstream analysis and tests

BREAKING CHANGE: Trainer.loss_history is a new attribute consolidating average losses per epoch, enabling robust learning dynamics assertions in tests
2025-10-17 16:25:39 +03:00
Sergey Penkovsky
613d784565 doc(datasets): update docstrings and tests 2025-10-17 10:49:45 +03:00
Sergey Penkovsky
38c271ca3c docs(models): update and expand docstrings for Mistral and its methods
- docs: add comprehensive docstrings for the Mistral class (in Russian) and its methods (forward, generate)
- docs: explain model architecture (GQA, Sliding Window Attention, SwiGLU, RMSNorm, RoPE), arguments, constraints, generation modes, usage examples, and references (Mistral, nucleus sampling)
- strictly documentation improvements, no logic/API changes

This commit makes Mistral model documentation clear and user-friendly for LLM engineering and inference.
2025-10-16 17:03:06 +03:00
Sergey Penkovsky
aec3c8adb6 docs(models): update and expand docstrings for LLaMA and generate method
- docs: add full, detailed Russian-language docstring for LLaMA.generate (sampling, top-k/top-p, examples, all parameter constraints and references)
- docs: bring LLaMA class header in line with modern LLM doc practices (motivation, architecture, references)
- no changes to logic, API, or tests

This makes the LLaMA model documentation fully transparent for all generation and inference modes.
2025-10-16 16:55:14 +03:00
Sergey Penkovsky
90eb2f4467 docs(models): expand docstring for generate method in GPT2
- docs: add detailed Russian-language docstring for generate method (args, nuances, sampling modes, error handling, usage examples, references to nucleus sampling and GPT-2 paper)
- strictly doc improvements, no logic or API changes

The updated documentation helps users clearly understand all generation options, constraints, and application modes in GPT2 LLMs.
2025-10-16 16:43:27 +03:00
Sergey Penkovsky
a3415d404a docs(models): update References in GPT docstring for vanilla implementation
- docs: update and focus References in GPT model docstring to only original GPT-1 (Radford et al., 2018) and BPE/Attention Is All You Need, removing GPT-2/HuggingFace links
- no changes to logic, API, or tests

This makes the documentation accurate for the vanilla GPT architecture and research lineage.
2025-10-16 16:33:53 +03:00
Sergey Penkovsky
9837ea3c3d docs(tokenizer): expand docstrings for BpeTokenizer
- docs: update and clarify docstrings for BpeTokenizer class and main methods (encode, decode)
- explain BPE algorithm, motivation, architecture, detailed usage examples, implementation details, references to original papers and major LLMs
- strictly doc improvements, no logic/API changes

This update makes tokenizer code easier to understand and use for language modeling research and engineering.
2025-10-16 15:26:17 +03:00
Sergey Penkovsky
baafca0546 docs(core): update docstrings for TokenEmbeddings
- docs: expand, clarify, and modernize docstrings for TokenEmbeddings class and its methods (__init__, forward, properties)
- explain layer purpose, motivation, math, parameter details, usage examples, and references
- no logic/API changes

This makes the input embedding code more accessible and maintainable for transformer and LLM development.
2025-10-16 15:14:53 +03:00
Sergey Penkovsky
516f9580fb docs(core): add docstrings and unit tests for SwiGLU block
- docs: rewrite and expand docstrings for SwiGLU class and forward method (motivation, math, architecture, usage, references to LLaMA/Mistral/PaLM)
- test: add unit tests for SwiGLU (shape, dtype, gradients, output range, fp16 support, reproducibility)
- strictly doc/tests, no logic or API changes

This improves transparency and reliability for gated FFN blocks in transformer architectures.
2025-10-16 15:09:09 +03:00
Sergey Penkovsky
64d33783e0 docs(core): add docstrings and unit tests for SiLU activation
- docs: expand and clarify docstrings for SiLU class and its method (mathematical formula, motivation, properties vs ReLU/GELU, usage, and references to Swish/LLM papers)
- test: add unit tests for SiLU (shape/dtype, behavior on large/small values, PyTorch reference, gradients, broadcast)
- no logic/API changes

This update improves reliability and usability of the SiLU activation module.
2025-10-16 14:48:50 +03:00
Sergey Penkovsky
6efc946027 docs(core): expand docstrings and add unit tests for RMSNorm
- docs: update/increase docstring detail for RMSNorm class and methods (motivation, formula, architecture, usage, references to LLaMA/PaLM/GPT)
- test: add comprehensive unit tests for RMSNorm (shape/type preservation, rms scaling, gradients for input and weights, fp16, large eps stability)

No code/API changes beyond docs and new tests.
2025-10-16 14:37:25 +03:00
Sergey Penkovsky
8018efae2a docs(core): expand docstrings for PositionalEmbeddings module
- docs: update and clarify docstrings for PositionalEmbeddings class and methods (__init__, forward)
- explain motivation, mathematical formulas, usage examples, architectural options (learned vs sinusoidal), external references
- no API or code changes

This makes the positional encoding component easier to understand and use for all transformer practitioners.
2025-10-16 14:09:05 +03:00
Sergey Penkovsky
0832d78acf docs(core): improve docstrings and add unit tests for GELU activation
- docs: rewrite and expand docstrings for GELU class and method (motivation, math formula, smoother ReLU for Transformers, usage, references)
- test: add dedicated tests for GELU (output shape, dtype, comparison with torch GELU, monotonicity, gradients, large/small value behavior)
- fix: align numerical test to allow for minor approximation difference vs PyTorch gelu

This update makes the GELU module more transparent and robust for deep learning practitioners and researchers.
2025-10-16 13:59:38 +03:00
Sergey Penkovsky
c338556cfe docs(core): improve and expand docstrings for FeedForward module
- docs: rewrite and clarify docstrings for FeedForward class and its methods (__init__, forward) with architectural explanation, pseudocode, motivation, parameter details, usage example, and key references (GELU, SwiGLU, Transformer)
- no changes to logic or APIs

This makes the feed-forward block more transparent for users and researchers working with transformer models.
2025-10-16 12:47:47 +03:00
Sergey Penkovsky
3a356f5d79 docs(core): improve and expand docstrings for Decoder module
- docs: rewrite and expand docstrings for Decoder class and its methods (__init__, forward)
- clarify the block’s architecture, pre-LN logic, flow with residual connections, and attention masking
- add mathematical pseudocode, motivation, feature list, usage example, and external references (papers, blog)
- no logic or behavior changes

This improves readability and makes the codebase easier to understand for transformer/LLM practitioners.
2025-10-16 12:40:46 +03:00
Sergey Penkovsky
923aa51e2a docs(core): add docstrings and unit tests for CachedDecoder module
- docs: Add detailed docstrings for CachedDecoder class and its methods (__init__, forward); explain autoregressive caching, architecture, math, usage, and links to GPT-2/LLM references
- test: Add comprehensive unit tests for CachedDecoder (initialization, forward with and without cache, cache chaining, output shape, error on long input, backward pass)
- These changes improve code clarity, reliability, and testing for decoder blocks with KV cache.
2025-10-16 12:30:53 +03:00
Sergey Penkovsky
ba3b04cec2 docs(core): add docstrings and unit tests for MistralDecoder
- docs: expanded docstrings for MistralDecoder class and methods (__init__, forward); explained architecture, key parameters, usage, and links to relevant papers (Mistral, Llama 2)
- test: add comprehensive unit tests for MistralDecoder (init, forward, cache handling, output shape, shape errors, backward)
- These changes improve explainability, reliability, and test coverage for the decoder module.
2025-10-15 18:07:11 +03:00
Sergey Penkovsky
e6ca8dee6f docs(core): add comprehensive docstrings and unit tests for GroupedQueryAttention (GQA)
- docs: Rewrite and expand docstrings for the GroupedQueryAttention class and all main methods (__init__, forward, _repeat_kv_heads, _create_sliding_window_mask):
    - explained GQA architecture and motivation
    - included mathematical formulas, step-by-step algorithms, usage examples
    - added references to relevant scientific papers (Mistral, Llama 2, etc.)
- test: Add dedicated unit tests for GQA (output shape correctness, mask/window logic, KV head replication, RoPE processing, error and edge-cases)
- docs/test: Documentation and tests now fully reflect modern GQA usage and best practices for LLM architectures

This commit makes the implementation, usage, and theoretical underpinnings of GQA transparent and reproducible for researchers and engineers.
2025-10-15 17:27:55 +03:00
Sergey Penkovsky
2e72dbaf07 test(llama): add unit tests for generation, cache, and edge cases
- Covers inference with and without cache and with sampling (top-k, top-p)
- Includes test for max sequence length (should raise ValueError)
- Verifies output shape and absence of dtype errors for the mask logic
- Minimal config and random data ensure tests are fast and robust

Motivation: Regression and integration protection for Llama decoding and sampling logic.
2025-10-15 14:37:35 +03:00
Sergey Penkovsky
dc440a3938 test(gpt2): add unit tests for generation, cache behavior, and error conditions
- Covers forward pass with and without KV-cache
- Verifies correct sequence generation for greedy, top-k, and top-p sampling
- Adds ValueError test for exceeding max sequence length
- Uses small random toy config and minimal setup for fast test feedback

Motivation: Prevent regressions in decoding, sampling, and KV-cache logic in GPT2 implementation.
2025-10-15 14:36:32 +03:00
Sergey Penkovsky
50d7593023 fix(gpt2, llama): proper top-k/top-p mask handling in sampling for PyTorch compatibility (bool/uint8)
- Refactored token selection logic in  methods of GPT2 and Llama classes.
- Masks are now created with dtype=torch.bool (or torch.uint8 for legacy PyTorch).
- Used True/False for mask/scatter instead of 1/0, ensuring correctness across PyTorch versions.
- Fixed RuntimeError: masked_fill_ only supports boolean masks, previously raised by uint8-masks in new PyTorch.
- Backward compatibility maintained: code works on PyTorch >=1.2 and for old clusters (via the else branch).

Motivation: Fixes sampling errors for all modern PyTorch users while keeping research code usable on old infra.
2025-10-15 14:35:10 +03:00
Sergey Penkovsky
38682e8c9d test(mistral): add unit tests for model generation and cache 2025-10-15 13:20:50 +03:00
Sergey Penkovsky
e791f7cd93 fix(mistral): fix top-k/top-p mask handling for PyTorch >=1.2 2025-10-15 13:20:30 +03:00
Sergey Penkovsky
d10044e4a7 refactor(core): refactor RoPE and MultiHeadAttention, add math-rich docs, expand tests, remove unused head_attention
- refactor: улучшена и унифицирована реализация RoPE, теперь поддерживаются строгие проверки размерности входа; внесены улучшения и структурные изменения в MultiHeadAttention (более понятная логика, строгая спецификация входов/выходов)
- docs: полностью переписаны docstrings для RoPE и MultiHeadAttention — включены математические формулы, ссылки на научные статьи, подробные пояснения по алгоритму, формату входных данных, ограничениям, примеры использования
- test: добавлены отдельные unit-тесты для RoPE (корректность формы, ошибки на неверную размерность, сохранение нормы, backward/градиенты, работу с параметрами start_pos и батчами)
- chore: удалён неиспользуемый модуль core/head_attention.py
- fix: теперь выбрасывается AssertionError при неправильной размерности входа RoPE; это позволило полностью покрыть тест-кейсы на ошибки

Этот коммит синхронизирует логику реализации базового внимания с современной практикой LLM, укрепляет документацию для инженеров и исследователей, а также расширяет надежность автотестирования библиотеки.
2025-10-15 11:04:07 +03:00
Sergey Penkovsky
ec0d2bd8d0 feat(mistral): add Mistral model implementation and configs
- implement Mistral model in llm/models/mistral/mistral.py with GroupedQueryAttention, SwiGLU, RoPE, sliding window attention
- add __init__.py for module export
- add config files for mistral training and generation
- update universal experiment runner to support Mistral model
- add notebook for Mistral experiments
2025-10-14 14:53:45 +03:00
Sergey Penkovsky
e5706a690d fix(rope, attention): корректное позиционирование RoPE при генерации с кэшем
- Исправлена ошибка расчёта позиции для RoPE (Rotary Positional Embeddings) при автодополнении с использованием кэша.
- В HeadAttention теперь передаётся start_pos в RoPE, вычисляемый из длины кэша.
- Обновлена сигнатура и логика метода RoPE.forward.
- Обновлен ноутбук llama.ipynb под новые интерфейсы и выводы.

BREAKING CHANGE: переопределён метод forward у RoPE, требуется обновить код, если RoPE использовался вручную.
2025-10-14 12:03:20 +03:00
Sergey Penkovsky
3e4815fcc6 refactor(experiments): migrate to universal runner + config structure, remove legacy scripts
- add universal runner run_llm_experiment.py with JSON-config driven LLM training / generation
- add configs for gpt, gpt2, llama (training/generation)
- remove individual train/generate scripts for each model
- update README with simple how-to for experiments block

BREAKING CHANGE: all llm_only experiments now run only through run_llm_experiment.py; legacy scripts removed
2025-10-14 11:57:23 +03:00
Sergey Penkovsky
0cc7850848 fix: format code 2025-10-06 23:03:01 +03:00
Sergey Penkovsky
237b86421e doc: update docstring 2025-10-06 23:02:03 +03:00
Sergey Penkovsky
712278e33c Рефакторинг: единообразие оформления кода (пробелы, кавычки, пустые строки), без изменения логики по всему проекту. 2025-10-06 22:57:19 +03:00
Sergey Penkovsky
332cad6159 Merge pull request #2 from pese-git/feature/llama
Feature/llama
2025-10-06 22:05:45 +03:00
121 changed files with 15593 additions and 4344 deletions

View File

@@ -1,15 +1,16 @@
# LLM Architecture Research
Исследовательский проект для разработки и обучения архитектур больших языковых моделей (LLM).
Исследовательский проект по разработке, обучению и сравнительному анализу современных архитектур больших языковых моделей (LLM): **GPT, GPT-2, LLaMA, Mistral**. Прямая поддержка интеграции с HuggingFace (через модуль `hf-proxy`).
## 🏗️ Архитектура проекта
Проект организован как монорепозиторий с использованием **uv** workspace:
- **`llm`** — основная библиотека с реализацией архитектур LLM (GPT, GPT-2)
- **`hf-proxy`** — адаптер для интеграции с HuggingFace
- **`experiments`** — скрипты обучения и экспериментов
- **`notebooks`** — исследовательские ноутбуки
- **`llm`** — основная библиотека с реализацией архитектур LLM (**GPT, GPT-2, LLaMA, Mistral**)
- **`hf-proxy`** — экспериментальный адаптер для интеграции с HuggingFace (загрузка, токенизация, экспериментальные скрипты). Функционал может изменяться и не гарантирует полной совместимости с будущими версиями HuggingFace Transformers.
- **`experiments`** — скрипты обучения и генерации (включая HF и собственные модели)
- **`notebooks`** — исследовательские ноутбуки, анализ архитектур
## 📁 Структура проекта
@@ -41,8 +42,11 @@ llm-arch-research/
│ │ │ ├── gpt.py
│ │ │ ├── gpt2.py
│ │ │ └── __init__.py
│ │ ── llama/ # LLaMA архитектура
│ │ ├── llama.py
│ │ ── llama/ # LLaMA архитектура
│ │ ├── llama.py
│ │ │ └── __init__.py
│ │ └── mistral/ # Mistral архитектура
│ │ ├── mistral.py
│ │ └── __init__.py
│ ├── training/ # утилиты обучения
│ │ ├── dataset.py
@@ -81,6 +85,18 @@ llm-arch-research/
## 🚀 Быстрый старт
**Пример запуска обучения и генерации для любых архитектур:**
```bash
python experiments/llm_only/run_llm_experiment.py --model mistral --action generate --config experiments/llm_only/configs/mistral_generate.json
```
**Использование собственных моделей с HuggingFace-интерфейсом:**
```python
from hf_proxy.hf_adapter import HFAdapter
hf_model = HFAdapter("mistralai/Mistral-7B-v0.1")
```
### Установка зависимостей
```bash
@@ -91,15 +107,17 @@ uv sync
uv sync --extra dev
```
### Запуск обучения GPT
## ⚡ Работа с экспериментами (experiments/llm_only, experiments/hf_integration)
```bash
# Обучение базовой GPT модели
uv run python experiments/llm_only/train_gpt_bpe.py
- В `experiments/llm_only`: универсальный скрипт для обучения и генерации LLM (включая LLaMA и Mistral) без HuggingFace — всё через собственную реализацию.
- В `experiments/hf_integration`: скрипты и примеры для генерации, обучения и тестирования моделей с помощью HuggingFace API (через hf-proxy). Позволяет использовать свои модели и токенизаторы как стандартные HF-объекты.
# Обучение с интеграцией HuggingFace
uv run python experiments/hf_integration/simple_hf_training.py
```
**Для моделей Mistral/Llama доступны оба сценария: прямая работа или через HuggingFace-прокси.**
*Конфиги и примеры см. в соответствующих папках.*
---
### Тестирование hf-proxy
@@ -212,33 +230,23 @@ dependencies = [
## 🎯 Реализованные возможности
### Архитектуры GPT и GPT-2
-Токенные и позиционные эмбеддинги
-Многоголовое внимание с causal mask
-Декодерные блоки с residual connections
-Layer normalization
- ✅ Dropout регуляризация
- ✅ Отдельные реализации GPT и GPT-2 (различия в масштабе и деталях архитектуры)
### Архитектуры
-GPT, GPT-2: Полностью воспроизводимые реализации, токенные и позиционные эмбеддинги, causal multi-head attention, LayerNorm
-LLaMA: Rotary Positional Embeddings (RoPE), RMSNorm, SwiGLU, оптимизированная память
-Mistral: Sliding Window Attention (оконное внимание), Grouped Query Attention (GQA), совместимость с HF
-Все архитектуры поддерживают обучение и генерацию текста
### Генерация текста
-Жадный поиск (greedy decoding)
- ✅ Вероятностное сэмплирование
- ✅ Top-k сэмплирование
- ✅ Nucleus sampling (top-p)
- ✅ Контроль температуры
-Greedy, sampling (Top-k, Top-p), контроль температуры, efficient caching
### Обучение
-Датасет для языкового моделирования
-Базовый тренировочный цикл
- ✅ Оптимизатор AdamW
- ✅ Сохранение чекпоинтов
-Языковое моделирование с кастомными и HF-токенизаторами
-AdamW, кастомные датасеты, сохранение чекпоинтов
### Интеграция с HuggingFace (hf-proxy)
-Адаптер моделей для совместимости с HF интерфейсами
-Адаптер токенизаторов с поддержкой всех методов HF
-Сохранение и загрузка в HF формате
- ✅ Совместимость с HF Trainer и pipelines
- ✅ Генерация через стандартные HF интерфейсы
-Экспорт/импорт моделей и токенизаторов в HF совместимый формат
-Генерация и обучение через HF Trainer, pipelines и т.д.
-Двусторонняя поддержка: собственные модели становятся HF-совместимыми и наоборот
## 🔬 Эксперименты с hf-proxy

View File

@@ -0,0 +1,148 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="1216" dy="316" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="92" value="" style="group" vertex="1" connectable="0" parent="1">
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="4" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="350" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="5" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="236.67904761904765" y="125"/>
<mxPoint x="292.38095238095235" y="125"/>
<mxPoint x="355" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="34.25858250276859" y="130"/>
<mxPoint x="89.96048726467328" y="130"/>
<mxPoint x="155" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="44" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -0,0 +1,413 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="2176" dy="702" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="92" value="" style="group" parent="1" vertex="1" connectable="0">
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;container=0;" parent="92" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="44" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="4" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="281.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="5" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="580" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="5" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="490.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="609.9976190476191" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="12" target="5" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="407.1428571428571" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="3" target="4" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="250" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="459.5238095238095" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="18" target="12" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="385.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="459.5238095238095" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="380.00428571428574" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="4" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="443.80952380952385" y="410" as="sourcePoint"/>
<mxPoint x="375.7142857142858" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" target="24" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="466.8571428571429" y="80" as="sourcePoint"/>
<mxPoint x="579.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="466.67904761904765" y="125"/>
<mxPoint x="522.3809523809523" y="125"/>
<mxPoint x="585" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="24" target="7" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="580.0019047619048" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="264.3255813953488" y="80" as="sourcePoint"/>
<mxPoint x="380.7142857142858" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="264.2585825027686" y="130"/>
<mxPoint x="319.96048726467325" y="130"/>
<mxPoint x="385" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="141" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="92">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="180" y="250" as="sourcePoint"/>
<mxPoint x="281" y="110" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="142" value="" style="endArrow=none;dashed=1;html=1;entryX=1;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" target="4">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="620" y="560" as="sourcePoint"/>
<mxPoint x="660" y="520" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="218" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="1">
<mxGeometry x="130" y="660" width="680" height="160" as="geometry"/>
</mxCell>
<mxCell id="195" style="edgeStyle=orthogonalEdgeStyle;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="147">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="196" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="148">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="197" style="edgeStyle=orthogonalEdgeStyle;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="149">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="143" value="X" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fff2cc;strokeColor=#d6b656;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry y="60" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="147" value="W&lt;sub&gt;k&lt;/sub&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="80" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="199" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="148" target="151">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="148" value="W&lt;sub&gt;q&lt;/sub&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="80" y="60" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="149" value="W&lt;sub&gt;v&lt;/sub&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="80" y="120" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="207" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="150" target="158">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="236" y="20"/>
<mxPoint x="236" y="50"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="150" value="K" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="158.97000000000003" width="41.03" height="40" as="geometry"/>
</mxCell>
<mxCell id="208" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="151" target="190">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="151" value="Q" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="158.97000000000003" y="60" width="41.03" height="40" as="geometry"/>
</mxCell>
<mxCell id="214" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="152" target="187">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="600" y="140"/>
<mxPoint x="600" y="80"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="152" value="V" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="158.97000000000003" y="120" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="215" style="edgeStyle=none;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="187">
<mxGeometry relative="1" as="geometry">
<mxPoint x="680" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="187" value="O" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fff2cc;strokeColor=#d6b656;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="620" y="60" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="211" style="edgeStyle=none;html=1;entryX=0;entryY=0;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="188" target="179">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="188" value="Scale" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;direction=east;rotation=90;fillColor=#f8cecc;strokeColor=#b85450;" vertex="1" parent="218">
<mxGeometry x="370" y="37.5" width="50" height="25" as="geometry"/>
</mxCell>
<mxCell id="213" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="189" target="187">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="600" y="50"/>
<mxPoint x="600" y="80"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="189" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;direction=east;rotation=90;fillColor=#e1d5e7;strokeColor=#9673a6;" vertex="1" parent="218">
<mxGeometry x="530" y="37.5" width="50" height="25" as="geometry"/>
</mxCell>
<mxCell id="209" style="edgeStyle=none;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="190" target="188">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="190" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="218">
<mxGeometry x="272.5" y="10" width="80" height="80" as="geometry"/>
</mxCell>
<mxCell id="153" value="" style="whiteSpace=wrap;html=1;aspect=fixed;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry width="80" height="80" as="geometry"/>
</mxCell>
<mxCell id="154" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="155" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="156" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="157" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="158" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="159" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="20" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="160" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="40" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="161" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="60" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="162" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="163" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="20" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="164" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="40" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="165" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="60" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="166" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="167" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="20" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="168" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="40" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="169" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="60" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="191" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="218">
<mxGeometry x="440" y="10" width="80" height="80" as="geometry"/>
</mxCell>
<mxCell id="170" value="" style="whiteSpace=wrap;html=1;aspect=fixed;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry width="80" height="80" as="geometry"/>
</mxCell>
<mxCell id="171" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="172" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="173" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="174" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="175" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="176" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="20" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="177" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="40" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="178" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="60" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="179" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="180" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="20" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="181" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="40" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="182" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="60" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="183" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="184" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="20" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="185" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="40" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="186" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="60" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="198" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="147" target="150">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="200" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" target="152">
<mxGeometry relative="1" as="geometry">
<mxPoint x="120" y="140" as="sourcePoint"/>
<mxPoint x="148.97000000000008" y="139.8599999999998" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="212" value="" style="endArrow=classic;html=1;exitX=1;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="182" target="189">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="610" y="50" as="sourcePoint"/>
<mxPoint x="660" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="219" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="1">
<mxGeometry x="289.99776556776555" y="520" width="250.00223443223445" height="90" as="geometry"/>
</mxCell>
<mxCell id="145" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" source="133" target="144">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="133" value="Concat" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;direction=east;rotation=90;" vertex="1" parent="219">
<mxGeometry x="132.50223443223445" y="32.5" width="50" height="25" as="geometry"/>
</mxCell>
<mxCell id="136" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" target="133">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="108.97435897435912" y="45" as="sourcePoint"/>
<mxPoint x="250.00223443223445" y="25" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="129" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
<mxGeometry x="30" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="130" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
<mxGeometry x="20" y="10" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="131" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
<mxGeometry x="10" y="20" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="132" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
<mxGeometry y="30" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="146" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" source="144">
<mxGeometry relative="1" as="geometry">
<mxPoint x="250.00223443223445" y="44.969696969697" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="144" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;direction=east;rotation=90;" vertex="1" parent="219">
<mxGeometry x="182.50223443223445" y="32.5" width="50" height="25" as="geometry"/>
</mxCell>
<mxCell id="220" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="90" y="690" as="sourcePoint"/>
<mxPoint x="290" y="610" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="221" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="370" y="610" as="sourcePoint"/>
<mxPoint x="830" y="700" as="targetPoint"/>
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -0,0 +1,148 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="979" dy="301" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="92" value="" style="group" parent="1" vertex="1" connectable="0">
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;strokeColor=#FF3333;" parent="92" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="4" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="350" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="5" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="236.67904761904765" y="125"/>
<mxPoint x="292.38095238095235" y="125"/>
<mxPoint x="355" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="34.25858250276859" y="130"/>
<mxPoint x="89.96048726467328" y="130"/>
<mxPoint x="155" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="44" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -0,0 +1,148 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="1216" dy="316" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="91" value="" style="group;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="1" vertex="1" connectable="0">
<mxGeometry x="40" y="360" width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="56" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="57" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="58" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="59" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="350" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="59" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="60" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="61" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="62" target="59" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="62" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="63" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="56" target="57" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="64" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="65" target="62" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="65" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="66" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="57" target="65" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="67" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" target="69" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="236.67904761904765" y="125"/>
<mxPoint x="292.38095238095235" y="125"/>
<mxPoint x="355" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="68" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="69" target="60" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="69" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="70" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" target="65" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="34.25858250276859" y="130"/>
<mxPoint x="89.96048726467328" y="130"/>
<mxPoint x="155" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="71" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="72" target="56" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="72" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#FF3333;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="73" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="74" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="75" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="76" target="79" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="76" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="77" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="60" target="76" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="78" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="79" target="85" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="79" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="80" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="81" target="83" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="81" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="82" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="83" target="87" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="83" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="84" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="85" target="81" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="85" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="86" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="87" target="88" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="87" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="88" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="89" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="90" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="89" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -0,0 +1,192 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="2176" dy="1029" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="107" value="" style="group" vertex="1" connectable="0" parent="1">
<mxGeometry x="120" y="170" width="1286" height="265" as="geometry"/>
</mxCell>
<mxCell id="92" value="" style="group" parent="107" vertex="1" connectable="0">
<mxGeometry width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="4" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="350" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="5" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="236.67904761904765" y="125"/>
<mxPoint x="292.38095238095235" y="125"/>
<mxPoint x="355" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="34.25858250276859" y="130"/>
<mxPoint x="89.96048726467328" y="130"/>
<mxPoint x="155" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="104" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="3">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="200" y="200" as="sourcePoint"/>
<mxPoint x="260.96000000000004" y="110" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="44" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="105" value="" style="endArrow=none;dashed=1;html=1;exitX=1;exitY=1;exitDx=0;exitDy=0;" edge="1" parent="107" source="5">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="660" y="140" as="sourcePoint"/>
<mxPoint x="620" y="190" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="106" value="" style="group" vertex="1" connectable="0" parent="107">
<mxGeometry x="450" y="195" width="170" height="70" as="geometry"/>
</mxCell>
<mxCell id="100" value="" style="edgeStyle=none;html=1;" edge="1" parent="106" source="93" target="99">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="93" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;" vertex="1" parent="106">
<mxGeometry x="-5" y="20" width="70" height="30" as="geometry"/>
</mxCell>
<mxCell id="96" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="106" target="93">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint y="35" as="sourcePoint"/>
<mxPoint y="35.00999999999999" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="102" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;" edge="1" parent="106" source="98">
<mxGeometry relative="1" as="geometry">
<mxPoint x="170" y="35.09433962264154" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="98" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;" vertex="1" parent="106">
<mxGeometry x="100" y="20" width="70" height="30" as="geometry"/>
</mxCell>
<mxCell id="101" value="" style="edgeStyle=none;html=1;" edge="1" parent="106" source="99" target="98">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="99" value="ReLU" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#e1d5e7;strokeColor=#9673a6;" vertex="1" parent="106">
<mxGeometry x="50" y="20" width="70" height="30" as="geometry"/>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

View File

@@ -14,12 +14,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from hf_proxy import HFAdapter, HFTokenizerAdapter, create_hf_pipeline
from shared.configs import (
TEST_PROMPTS, GENERATION_CONFIG, PATHS
)
from shared.data import (
print_experiment_info, ensure_directories, ExperimentLogger
)
from shared.configs import TEST_PROMPTS, GENERATION_CONFIG, PATHS
from shared.data import print_experiment_info, ensure_directories, ExperimentLogger
def load_hf_model_and_tokenizer() -> tuple:
@@ -41,9 +37,7 @@ def load_hf_model_and_tokenizer() -> tuple:
)
if not os.path.exists(tokenizer_path):
raise FileNotFoundError(
f"Токенизатор не найден: {tokenizer_path}"
)
raise FileNotFoundError(f"Токенизатор не найден: {tokenizer_path}")
# Загружаем адаптированный токенизатор
print("🔧 Загрузка адаптированного токенизатора...")
@@ -52,8 +46,9 @@ def load_hf_model_and_tokenizer() -> tuple:
# Загружаем конфигурацию модели
import json
config_path = os.path.join(model_path, "config.json")
with open(config_path, 'r', encoding='utf-8') as f:
with open(config_path, "r", encoding="utf-8") as f:
model_config = json.load(f)
# Загружаем модель через HFAdapter с правильной конфигурацией
@@ -62,6 +57,7 @@ def load_hf_model_and_tokenizer() -> tuple:
# Создаем конфигурацию из сохраненного config.json
from hf_proxy import HFAdapterConfig
hf_config = HFAdapterConfig(
vocab_size=model_config["vocab_size"],
hidden_size=model_config["hidden_size"],
@@ -69,7 +65,9 @@ def load_hf_model_and_tokenizer() -> tuple:
num_attention_heads=model_config["num_attention_heads"],
max_position_embeddings=model_config["max_position_embeddings"],
hidden_dropout_prob=model_config.get("hidden_dropout_prob", 0.1),
attention_probs_dropout_prob=model_config.get("attention_probs_dropout_prob", 0.1),
attention_probs_dropout_prob=model_config.get(
"attention_probs_dropout_prob", 0.1
),
)
hf_model = HFAdapter.from_pretrained(model_bin_path, hf_config=hf_config)
@@ -97,7 +95,7 @@ def test_hf_pipeline(hf_model, hf_tokenizer):
device="cpu",
max_length=50,
do_sample=True,
temperature=0.7
temperature=0.7,
)
print("✅ HuggingFace pipeline создан")
@@ -132,8 +130,10 @@ def generate_with_hf_model(hf_model, hf_tokenizer, prompt: str, config: dict) ->
str: Сгенерированный текст
"""
print(f"🔤 Промпт: '{prompt}'")
print(f"📊 Параметры: max_tokens={config['max_new_tokens']}, "
f"temp={config['temperature']}, sample={config['do_sample']}")
print(
f"📊 Параметры: max_tokens={config['max_new_tokens']}, "
f"temp={config['temperature']}, sample={config['do_sample']}"
)
# Кодируем через адаптированный токенизатор
inputs = hf_tokenizer(prompt, return_tensors="pt")
@@ -144,12 +144,12 @@ def generate_with_hf_model(hf_model, hf_tokenizer, prompt: str, config: dict) ->
# Генерируем через адаптированную модель
with torch.no_grad():
generated_ids = hf_model.generate(
input_ids=inputs['input_ids'],
input_ids=inputs["input_ids"],
max_new_tokens=config["max_new_tokens"],
do_sample=config["do_sample"],
temperature=config["temperature"],
top_k=config["top_k"],
top_p=config["top_p"]
top_p=config["top_p"],
)
# Декодируем через адаптированный токенизатор
@@ -174,23 +174,29 @@ def test_different_hf_strategies(hf_model, hf_tokenizer, prompt: str):
{"name": "🎯 Жадный поиск", "do_sample": False, "temperature": 1.0},
{"name": "🎲 Вероятностная (temp=0.7)", "do_sample": True, "temperature": 0.7},
{"name": "🔥 Случайная (temp=1.2)", "do_sample": True, "temperature": 1.2},
{"name": "❄️ Детерминированная (temp=0.3)", "do_sample": True, "temperature": 0.3},
{
"name": "❄️ Детерминированная (temp=0.3)",
"do_sample": True,
"temperature": 0.3,
},
]
for strategy in strategies:
print(f"\n{strategy['name']}:")
try:
config = GENERATION_CONFIG.copy()
config.update({
"do_sample": strategy["do_sample"],
"temperature": strategy["temperature"],
"max_new_tokens": 20
})
config.update(
{
"do_sample": strategy["do_sample"],
"temperature": strategy["temperature"],
"max_new_tokens": 20,
}
)
generated = generate_with_hf_model(hf_model, hf_tokenizer, prompt, config)
# Выделяем сгенерированную часть
generated_part = generated[len(prompt):]
generated_part = generated[len(prompt) :]
print(f" 📤 Промпт: '{prompt}'")
print(f" 🎯 Сгенерировано: '{generated_part}'")
print(f" 📄 Полный текст: '{generated}'")
@@ -215,7 +221,7 @@ def analyze_hf_tokenization(hf_tokenizer, texts: list):
# Токенизация через адаптер
inputs = hf_tokenizer(text, return_tensors="pt")
tokens = inputs['input_ids'].tolist()[0]
tokens = inputs["input_ids"].tolist()[0]
token_strings = hf_tokenizer.tokenize(text)
print(f" Токены (ID): {tokens}")
@@ -247,7 +253,7 @@ def interactive_hf_generation(hf_model, hf_tokenizer):
try:
user_input = input("\n🔤 Введите промпт: ").strip()
if user_input.lower() in ['exit', 'quit', 'выход']:
if user_input.lower() in ["exit", "quit", "выход"]:
break
if not user_input:
@@ -258,7 +264,7 @@ def interactive_hf_generation(hf_model, hf_tokenizer):
max_tokens = int(input("📏 Макс. токенов [50]: ") or "50")
temperature = float(input("🌡️ Температура [0.7]: ") or "0.7")
do_sample_input = input("🎲 Сэмплирование (y/n) [y]: ").lower()
do_sample = do_sample_input != 'n'
do_sample = do_sample_input != "n"
except:
max_tokens = 50
temperature = 0.7
@@ -266,15 +272,19 @@ def interactive_hf_generation(hf_model, hf_tokenizer):
print("⚠️ Использую параметры по умолчанию")
config = GENERATION_CONFIG.copy()
config.update({
"max_new_tokens": max_tokens,
"temperature": temperature,
"do_sample": do_sample
})
config.update(
{
"max_new_tokens": max_tokens,
"temperature": temperature,
"do_sample": do_sample,
}
)
generated = generate_with_hf_model(hf_model, hf_tokenizer, user_input, config)
generated = generate_with_hf_model(
hf_model, hf_tokenizer, user_input, config
)
generated_part = generated[len(user_input):]
generated_part = generated[len(user_input) :]
print(f"\n🎯 Результат:")
print(f" 📤 Промпт: '{user_input}'")
print(f" 🎯 Сгенерировано: '{generated_part}'")
@@ -295,7 +305,7 @@ def main():
"model": "GPT через HFAdapter",
"tokenizer": "BPE через HFTokenizerAdapter",
"инструменты": "HuggingFace pipeline & генерация",
"стратегия": "интеграция с HF экосистемой"
"стратегия": "интеграция с HF экосистемой",
}
print_experiment_info(experiment_name, experiment_config)
@@ -310,7 +320,7 @@ def main():
analysis_texts = [
"Искусственный интеллект",
"Нейронные сети",
"Машинное обучение"
"Машинное обучение",
]
analyze_hf_tokenization(hf_tokenizer, analysis_texts)
@@ -326,10 +336,12 @@ def main():
print("-" * 40)
try:
generated = generate_with_hf_model(hf_model, hf_tokenizer, prompt, GENERATION_CONFIG)
generated = generate_with_hf_model(
hf_model, hf_tokenizer, prompt, GENERATION_CONFIG
)
# Выделяем сгенерированную часть
generated_part = generated[len(prompt):]
generated_part = generated[len(prompt) :]
print(f"📤 Промпт: '{prompt}'")
print(f"🎯 Сгенерировано: '{generated_part}'")
@@ -365,6 +377,7 @@ def main():
except Exception as e:
print(f"❌ Ошибка в эксперименте: {e}")
import traceback
traceback.print_exc()

View File

@@ -19,8 +19,12 @@ from llm.tokenizers import BPETokenizer
from hf_proxy import HFAdapter, HFTokenizerAdapter
from shared.configs import (
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
TRAINING_CONFIG, PATHS, TEST_PROMPTS
TRAIN_TEXTS,
BASE_GPT_CONFIG,
BPE_CONFIG,
TRAINING_CONFIG,
PATHS,
TEST_PROMPTS,
)
@@ -45,18 +49,15 @@ def create_dataset(hf_tokenizer, texts, max_length=128):
max_length=max_length,
truncation=True,
padding=False,
return_tensors="pt"
return_tensors="pt",
)
input_ids = inputs['input_ids'][0]
input_ids = inputs["input_ids"][0]
# Создаем метки для языкового моделирования
labels = input_ids.clone()
dataset.append({
'input_ids': input_ids,
'labels': labels
})
dataset.append({"input_ids": input_ids, "labels": labels})
return dataset
@@ -84,10 +85,7 @@ def manual_training_loop(hf_model, hf_tokenizer, train_texts, val_texts, config)
print(f"📊 Данные: {len(train_dataset)} train, {len(val_dataset)} validation")
# Оптимизатор
optimizer = torch.optim.AdamW(
hf_model.parameters(),
lr=config["learning_rate"]
)
optimizer = torch.optim.AdamW(hf_model.parameters(), lr=config["learning_rate"])
# Функция потерь
loss_fn = nn.CrossEntropyLoss()
@@ -105,8 +103,8 @@ def manual_training_loop(hf_model, hf_tokenizer, train_texts, val_texts, config)
for i, batch in enumerate(train_dataset):
optimizer.zero_grad()
input_ids = batch['input_ids'].unsqueeze(0) # [1, seq_len]
labels = batch['labels'].unsqueeze(0) # [1, seq_len]
input_ids = batch["input_ids"].unsqueeze(0) # [1, seq_len]
labels = batch["labels"].unsqueeze(0) # [1, seq_len]
# Forward pass
outputs = hf_model(input_ids=input_ids, labels=labels)
@@ -130,8 +128,8 @@ def manual_training_loop(hf_model, hf_tokenizer, train_texts, val_texts, config)
epoch_val_loss = 0
with torch.no_grad():
for batch in val_dataset:
input_ids = batch['input_ids'].unsqueeze(0)
labels = batch['labels'].unsqueeze(0)
input_ids = batch["input_ids"].unsqueeze(0)
labels = batch["labels"].unsqueeze(0)
outputs = hf_model(input_ids=input_ids, labels=labels)
epoch_val_loss += outputs.loss.item()
@@ -143,10 +141,10 @@ def manual_training_loop(hf_model, hf_tokenizer, train_texts, val_texts, config)
hf_model.train()
return {
'train_losses': train_losses,
'val_losses': val_losses,
'final_train_loss': train_losses[-1],
'final_val_loss': val_losses[-1]
"train_losses": train_losses,
"val_losses": val_losses,
"final_train_loss": train_losses[-1],
"final_val_loss": val_losses[-1],
}
@@ -170,10 +168,10 @@ def test_generation_after_training(hf_model, hf_tokenizer, test_prompts):
with torch.no_grad():
generated = hf_model.generate(
input_ids=inputs['input_ids'],
input_ids=inputs["input_ids"],
max_new_tokens=20,
do_sample=True,
temperature=0.8
temperature=0.8,
)
generated_text = hf_tokenizer.decode(generated[0], skip_special_tokens=True)
@@ -192,7 +190,9 @@ def main():
try:
# === Подготовка данных ===
print("🔧 Подготовка данных...")
train_texts = TRAIN_TEXTS[:10] # Используем меньше данных для быстрого тестирования
train_texts = TRAIN_TEXTS[
:10
] # Используем меньше данных для быстрого тестирования
val_texts = TRAIN_TEXTS[10:12]
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
@@ -203,7 +203,7 @@ def main():
llm_tokenizer.train(
texts=train_texts,
vocab_size=BPE_CONFIG["vocab_size"],
special_tokens=BPE_CONFIG["special_tokens"]
special_tokens=BPE_CONFIG["special_tokens"],
)
hf_tokenizer = HFTokenizerAdapter(llm_tokenizer)
@@ -227,7 +227,7 @@ def main():
training_config = {
"learning_rate": TRAINING_CONFIG["learning_rate"],
"num_epochs": 2, # Меньше эпох для быстрого тестирования
"batch_size": TRAINING_CONFIG["batch_size"]
"batch_size": TRAINING_CONFIG["batch_size"],
}
results = manual_training_loop(
@@ -255,20 +255,23 @@ def main():
# Сохраняем модель
HFAdapter.save_pretrained(
hf_model,
"checkpoints/hf_simple_trained",
tokenizer=hf_tokenizer
hf_model, "checkpoints/hf_simple_trained", tokenizer=hf_tokenizer
)
print("✅ Модель сохранена")
# Сохраняем результаты
results_path = "checkpoints/simple_training_results.json"
with open(results_path, 'w', encoding='utf-8') as f:
json.dump({
'training_config': training_config,
'model_config': model_config,
'results': results
}, f, indent=2, ensure_ascii=False)
with open(results_path, "w", encoding="utf-8") as f:
json.dump(
{
"training_config": training_config,
"model_config": model_config,
"results": results,
},
f,
indent=2,
ensure_ascii=False,
)
print(f"✅ Результаты сохранены в {results_path}")
print(f"\n🎉 Упрощенное обучение завершено успешно!")
@@ -278,6 +281,7 @@ def main():
except Exception as e:
print(f"❌ Ошибка в эксперименте: {e}")
import traceback
traceback.print_exc()

View File

@@ -16,8 +16,11 @@ from llm.tokenizers import BPETokenizer
from hf_proxy import HFAdapter, HFTokenizerAdapter
from shared.configs import (
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
TEST_PROMPTS, GENERATION_CONFIG
TRAIN_TEXTS,
BASE_GPT_CONFIG,
BPE_CONFIG,
TEST_PROMPTS,
GENERATION_CONFIG,
)
@@ -31,7 +34,7 @@ def test_basic_hf_integration():
llm_tokenizer.train(
texts=TRAIN_TEXTS,
vocab_size=BPE_CONFIG["vocab_size"],
special_tokens=BPE_CONFIG["special_tokens"]
special_tokens=BPE_CONFIG["special_tokens"],
)
hf_tokenizer = HFTokenizerAdapter(llm_tokenizer)
@@ -62,7 +65,7 @@ def test_basic_hf_integration():
print(f" HF адаптер: {hf_inputs['input_ids'].shape}")
# Декодирование
decoded = hf_tokenizer.decode(hf_inputs['input_ids'][0])
decoded = hf_tokenizer.decode(hf_inputs["input_ids"][0])
print(f" Декодированный: '{decoded}'")
# === Тестирование forward pass ===
@@ -87,10 +90,10 @@ def test_basic_hf_integration():
with torch.no_grad():
generated = hf_model.generate(
input_ids=inputs['input_ids'],
input_ids=inputs["input_ids"],
max_new_tokens=10,
do_sample=True,
temperature=0.8
temperature=0.8,
)
generated_text = hf_tokenizer.decode(generated[0], skip_special_tokens=True)
@@ -123,7 +126,9 @@ def test_basic_hf_integration():
test_input = hf_tokenizer("Тест", return_tensors="pt")
with torch.no_grad():
loaded_outputs = loaded_model(**test_input)
print(f" ✅ Загруженная модель работает (logits: {loaded_outputs.logits.shape})")
print(
f" ✅ Загруженная модель работает (logits: {loaded_outputs.logits.shape})"
)
except Exception as e:
print(f" ❌ Ошибка сохранения/загрузки: {e}")
@@ -140,7 +145,7 @@ def test_hf_tokenizer_methods():
llm_tokenizer.train(
texts=TRAIN_TEXTS[:5],
vocab_size=500,
special_tokens=BPE_CONFIG["special_tokens"]
special_tokens=BPE_CONFIG["special_tokens"],
)
hf_tokenizer = HFTokenizerAdapter(llm_tokenizer)
@@ -199,6 +204,7 @@ def main():
except Exception as e:
print(f"\n❌ Ошибка в тестировании: {e}")
import traceback
traceback.print_exc()

View File

@@ -17,12 +17,18 @@ from llm.tokenizers import BPETokenizer
from hf_proxy import HFAdapter, HFTokenizerAdapter
from shared.configs import (
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
TRAINING_CONFIG, PATHS, TEST_PROMPTS
TRAIN_TEXTS,
BASE_GPT_CONFIG,
BPE_CONFIG,
TRAINING_CONFIG,
PATHS,
TEST_PROMPTS,
)
from shared.data import (
load_training_data, ensure_directories,
print_experiment_info, ExperimentLogger
load_training_data,
ensure_directories,
print_experiment_info,
ExperimentLogger,
)
@@ -50,7 +56,7 @@ def setup_hf_training():
llm_tokenizer.train(
texts=TRAIN_TEXTS,
vocab_size=BPE_CONFIG["vocab_size"],
special_tokens=BPE_CONFIG["special_tokens"]
special_tokens=BPE_CONFIG["special_tokens"],
)
llm_tokenizer.save(PATHS["bpe_tokenizer"])
print(f"✅ Токенизатор обучен и сохранен")
@@ -117,7 +123,7 @@ def main():
"tokenizer": "BPE через HFTokenizerAdapter",
"trainer": "HuggingFace Trainer",
"vocab_size": BPE_CONFIG["vocab_size"],
"training_epochs": TRAINING_CONFIG["num_epochs"]
"training_epochs": TRAINING_CONFIG["num_epochs"],
}
print_experiment_info(experiment_name, experiment_config)
@@ -126,7 +132,14 @@ def main():
try:
# Настраиваем обучение
hf_model, hf_tokenizer, llm_tokenizer, model_config, train_texts, val_texts = setup_hf_training()
(
hf_model,
hf_tokenizer,
llm_tokenizer,
model_config,
train_texts,
val_texts,
) = setup_hf_training()
# Тестируем интеграцию
test_hf_integration(hf_model, hf_tokenizer, llm_tokenizer)
@@ -173,7 +186,7 @@ def main():
from transformers import (
Trainer,
TrainingArguments,
DataCollatorForLanguageModeling
DataCollatorForLanguageModeling,
)
# Data collator для языкового моделирования
@@ -261,13 +274,15 @@ def main():
with torch.no_grad():
generated = hf_model.generate(
input_ids=inputs['input_ids'],
input_ids=inputs["input_ids"],
max_new_tokens=20,
do_sample=True,
temperature=0.8
temperature=0.8,
)
generated_text = hf_tokenizer.decode(generated[0], skip_special_tokens=True)
generated_text = hf_tokenizer.decode(
generated[0], skip_special_tokens=True
)
print(f"🎯 Результат: '{generated_text}'")
except Exception as e:
@@ -278,8 +293,8 @@ def main():
"experiment": experiment_name,
"model_config": model_config,
"training_config": TRAINING_CONFIG,
"final_loss": train_result.metrics.get('train_loss', 'N/A'),
"eval_loss": train_result.metrics.get('eval_loss', 'N/A')
"final_loss": train_result.metrics.get("train_loss", "N/A"),
"eval_loss": train_result.metrics.get("eval_loss", "N/A"),
}
logger.save_logs("checkpoints/hf_integration_training_logs.json")
@@ -291,6 +306,7 @@ def main():
except Exception as e:
print(f"❌ Ошибка в эксперименте: {e}")
import traceback
traceback.print_exc()

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"Open weights",
"The Llama model is",
"Efficient transformers"
],
"model_config_path": "checkpoints/gemma-bpe/config.json",
"model_weights": "checkpoints/gemma-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/gemma_only_generation_logs.json"
}

View File

@@ -0,0 +1,28 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["Open source AI", "What is Llama?"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_q_heads": 4,
"num_kv_heads": 2,
"head_size": 64,
"num_layers": 4,
"max_position_embeddings": 512,
"num_experts": 8,
"top_k_experts": 2,
"window_size": 16,
"dropout": 0.1
},
"model_weights": "checkpoints/gemma-bpe/model.pt",
"model_config_path": "checkpoints/gemma-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/gemma_only_training_logs.json"
}

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"Нейронные сети",
"Обработка естественного языка",
"GPT-2 — это"
],
"model_config_path": "checkpoints/gpt2-bpe/config.json",
"model_weights": "checkpoints/gpt2-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/llm_only_generation_logs.json"
}

View File

@@ -0,0 +1,23 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["Искусственный интеллект", "Python — это"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_heads": 4,
"num_layers": 4,
"max_position_embeddings": 128,
"dropout": 0.1
},
"model_weights": "checkpoints/gpt2-bpe/model.pt",
"model_config_path": "checkpoints/gpt2-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/gpt2_only_training_logs.json"
}

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"The neural network",
"Transformer architecture",
"GPT models are"
],
"model_config_path": "checkpoints/gpt-bpe/config.json",
"model_weights": "checkpoints/gpt-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/llm_only_generation_logs.json"
}

View File

@@ -0,0 +1,23 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["GPT language model", "Machine learning basics"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_heads": 4,
"num_layers": 4,
"max_position_embeddings": 128,
"dropout": 0.1
},
"model_weights": "checkpoints/gpt-bpe/model.pt",
"model_config_path": "checkpoints/gpt-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/gpt_only_training_logs.json"
}

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"Open weights",
"The Llama model is",
"Efficient transformers"
],
"model_config_path": "checkpoints/llama-bpe/config.json",
"model_weights": "checkpoints/llama-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/llm_only_generation_logs.json"
}

View File

@@ -0,0 +1,23 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["Open source AI", "What is Llama?"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_heads": 4,
"num_layers": 4,
"max_position_embeddings": 128,
"dropout": 0.1
},
"model_weights": "checkpoints/llama-bpe/model.pt",
"model_config_path": "checkpoints/llama-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/llama_only_training_logs.json"
}

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"Open weights",
"The Llama model is",
"Efficient transformers"
],
"model_config_path": "checkpoints/mistral-bpe/config.json",
"model_weights": "checkpoints/mistral-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/mistral_only_generation_logs.json"
}

View File

@@ -0,0 +1,26 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["Open source AI", "What is Llama?"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_q_heads": 4,
"num_kv_heads": 2,
"head_size": 64,
"num_layers": 4,
"max_position_embeddings": 512,
"window_size": 16,
"dropout": 0.1
},
"model_weights": "checkpoints/mistral-bpe/model.pt",
"model_config_path": "checkpoints/mistral-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/mistral_only_training_logs.json"
}

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"Open weights",
"The Llama model is",
"Efficient transformers"
],
"model_config_path": "checkpoints/mixtral-bpe/config.json",
"model_weights": "checkpoints/mixtral-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/mixtral_only_generation_logs.json"
}

View File

@@ -0,0 +1,28 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["Open source AI", "What is Llama?"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_q_heads": 4,
"num_kv_heads": 2,
"head_size": 64,
"num_layers": 4,
"max_position_embeddings": 512,
"num_experts": 8,
"top_k_experts": 2,
"window_size": 16,
"dropout": 0.1
},
"model_weights": "checkpoints/mixtral-bpe/model.pt",
"model_config_path": "checkpoints/mixtral-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/mixtral_only_training_logs.json"
}

View File

@@ -1,313 +0,0 @@
#!/usr/bin/env python3
"""
Experiment: generate_gpt_bpe.py
Description: Генерация текста обученной GPT моделью с BPE токенизатором.
Использует только библиотеку llm без зависимостей от HuggingFace.
"""
import torch
import os
import sys
# Добавляем путь к shared модулям
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from llm.models.gpt import GPT2
from llm.tokenizers import BPETokenizer
from shared.configs import (
BASE_GPT_CONFIG, TEST_PROMPTS, GENERATION_CONFIG, PATHS
)
from shared.data import (
print_experiment_info, ensure_directories, ExperimentLogger
)
def load_model_and_tokenizer() -> tuple:
"""
Загружает обученную модель и токенизатор.
Returns:
tuple: (модель, токенизатор, конфигурация)
"""
# Проверяем существование файлов
if not os.path.exists(PATHS["gpt_bpe_model"]):
raise FileNotFoundError(
f"Модель не найдена: {PATHS['gpt_bpe_model']}\n"
f"Сначала обучите модель: uv run python experiments/llm_only/train_gpt_bpe.py"
)
if not os.path.exists(PATHS["bpe_tokenizer"]):
raise FileNotFoundError(
f"Токенизатор не найден: {PATHS['bpe_tokenizer']}"
)
# Загружаем конфигурацию модели
import json
with open(PATHS["gpt_bpe_config"], 'r', encoding='utf-8') as f:
model_config = json.load(f)
# Загружаем токенизатор
print("🔧 Загрузка BPE токенизатора...")
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
# Загружаем модель
print("🔧 Загрузка GPT2 модели...")
model = GPT2(model_config)
model.load_state_dict(torch.load(PATHS["gpt_bpe_model"], map_location='cpu'))
model.eval()
print("✅ Модель загружена")
return model, tokenizer, model_config
def generate_text(
model: GPT2,
tokenizer: BPETokenizer,
prompt: str,
config: dict
) -> str:
"""
Генерирует текст на основе промпта.
Args:
model: Обученная GPT модель
tokenizer: BPE токенизатор
prompt: Входной текст
config: Конфигурация генерации
Returns:
str: Сгенерированный текст
"""
print(f"🔤 Промпт: '{prompt}'")
print(f"📊 Параметры: max_tokens={config['max_new_tokens']}, "
f"temp={config['temperature']}, sample={config['do_sample']}")
# Кодируем промпт
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_tensor = torch.tensor([input_ids], dtype=torch.long)
print(f"🎯 Токены промпта: {input_ids}")
print(f"🎯 Токены (текст): {tokenizer.tokenize(prompt)}")
print("🔄 Генерация...")
# Генерируем текст
with torch.no_grad():
generated_ids = model.generate(
x=input_tensor,
max_new_tokens=config["max_new_tokens"],
do_sample=config["do_sample"],
temperature=config["temperature"],
top_k=config["top_k"],
top_p=config["top_p"]
)
# Декодируем результат
generated_text = tokenizer.decode(generated_ids[0].tolist())
return generated_text
def test_different_strategies(model: GPT2, tokenizer: BPETokenizer, prompt: str):
"""
Тестирует разные стратегии генерации на одном промпте.
Args:
model: Обученная модель
tokenizer: BPE токенизатор
prompt: Тестовый промпт
"""
print(f"\n🎭 Сравнение стратегий генерации для промпта: '{prompt}'")
print("=" * 60)
strategies = [
{"name": "🎯 Жадный поиск", "do_sample": False, "temperature": 1.0},
{"name": "🎲 Вероятностная (temp=0.7)", "do_sample": True, "temperature": 0.7},
{"name": "🔥 Случайная (temp=1.2)", "do_sample": True, "temperature": 1.2},
{"name": "❄️ Детерминированная (temp=0.3)", "do_sample": True, "temperature": 0.3},
]
for strategy in strategies:
print(f"\n{strategy['name']}:")
try:
config = GENERATION_CONFIG.copy()
config.update({
"do_sample": strategy["do_sample"],
"temperature": strategy["temperature"],
"max_new_tokens": 20
})
generated = generate_text(model, tokenizer, prompt, config)
# Выделяем сгенерированную часть
generated_part = generated[len(prompt):]
print(f" 📤 Промпт: '{prompt}'")
print(f" 🎯 Сгенерировано: '{generated_part}'")
print(f" 📄 Полный текст: '{generated}'")
except Exception as e:
print(f" ❌ Ошибка: {e}")
def analyze_tokenization(tokenizer: BPETokenizer, texts: list):
"""
Анализирует токенизацию различных текстов.
Args:
tokenizer: BPE токенизатор
texts: Список текстов для анализа
"""
print(f"\n🔍 Анализ токенизации BPE:")
print("=" * 50)
for i, text in enumerate(texts):
print(f"\nТекст {i+1}: '{text}'")
# Токенизация
tokens = tokenizer.encode(text, add_special_tokens=False)
token_strings = tokenizer.tokenize(text)
print(f" Токены (ID): {tokens}")
print(f" Токены (текст): {token_strings}")
print(f" Количество токенов: {len(tokens)}")
print(f" Эффективность: {len(text)} символов → {len(tokens)} токенов")
# Декодирование обратно
decoded = tokenizer.decode(tokens)
if text == decoded:
print(f" ✅ Декодирование корректно")
else:
print(f" ⚠️ Расхождения: '{decoded}'")
def interactive_generation(model: GPT2, tokenizer: BPETokenizer):
"""
Режим интерактивной генерации.
Args:
model: Обученная модель
tokenizer: BPE токенизатор
"""
print(f"\n💬 Интерактивная генерация (для выхода введите 'exit')")
print("-" * 50)
while True:
try:
user_input = input("\n🔤 Введите промпт: ").strip()
if user_input.lower() in ['exit', 'quit', 'выход']:
break
if not user_input:
continue
# Запрашиваем параметры
try:
max_tokens = int(input("📏 Макс. токенов [50]: ") or "50")
temperature = float(input("🌡️ Температура [0.7]: ") or "0.7")
do_sample_input = input("🎲 Сэмплирование (y/n) [y]: ").lower()
do_sample = do_sample_input != 'n'
except:
max_tokens = 50
temperature = 0.7
do_sample = True
print("⚠️ Использую параметры по умолчанию")
config = GENERATION_CONFIG.copy()
config.update({
"max_new_tokens": max_tokens,
"temperature": temperature,
"do_sample": do_sample
})
generated = generate_text(model, tokenizer, user_input, config)
generated_part = generated[len(user_input):]
print(f"\n🎯 Результат:")
print(f" 📤 Промпт: '{user_input}'")
print(f" 🎯 Сгенерировано: '{generated_part}'")
print(f" 📄 Полный текст: '{generated}'")
except KeyboardInterrupt:
print("\n👋 Завершение работы...")
break
except Exception as e:
print(f"❌ Ошибка: {e}")
def main():
"""Основная функция эксперимента."""
# === Настройка эксперимента ===
experiment_name = "Генерация текста GPT2 + BPE (только llm)"
experiment_config = {
"model": "GPT2 с BPE токенизатором",
"стратегия": "автономная генерация",
"вход": "промпты",
"выход": "сгенерированный текст"
}
print_experiment_info(experiment_name, experiment_config)
ensure_directories()
logger = ExperimentLogger(experiment_name)
try:
# Загружаем модель и токенизатор
model, tokenizer, model_config = load_model_and_tokenizer()
# === Анализ токенизации ===
analysis_texts = [
"Искусственный интеллект",
"Нейронные сети",
"Машинное обучение",
]
analyze_tokenization(tokenizer, analysis_texts)
# === Генерация с разными промптами ===
print(f"\n🎯 Генерация текста с разными промптами")
print("=" * 60)
for i, prompt in enumerate(TEST_PROMPTS):
print(f"\n📝 Пример {i+1}/{len(TEST_PROMPTS)}")
print("-" * 40)
try:
generated = generate_text(model, tokenizer, prompt, GENERATION_CONFIG)
# Выделяем сгенерированную часть
generated_part = generated[len(prompt):]
print(f"📤 Промпт: '{prompt}'")
print(f"🎯 Сгенерировано: '{generated_part}'")
print(f"📄 Полный текст: '{generated}'")
print(f"📏 Длина: {len(generated)} символов")
# Логируем успешную генерацию
logger.log_metric(f"generation_length_{i}", len(generated))
except Exception as e:
print(f"❌ Ошибка при генерации: {e}")
continue
# === Сравнение стратегий генерации ===
test_prompt = "Искусственный"
test_different_strategies(model, tokenizer, test_prompt)
# === Интерактивная генерация ===
interactive_generation(model, tokenizer)
# === Сохранение результатов ===
logger.save_logs("checkpoints/llm_only_generation_logs.json")
print(f"\n🎉 Эксперимент генерации завершен успешно!")
except FileNotFoundError as e:
print(f"{e}")
except Exception as e:
print(f"❌ Ошибка в эксперименте: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -1,313 +0,0 @@
#!/usr/bin/env python3
"""
Experiment: generate_gpt_bpe.py
Description: Генерация текста обученной GPT моделью с BPE токенизатором.
Использует только библиотеку llm без зависимостей от HuggingFace.
"""
import torch
import os
import sys
# Добавляем путь к shared модулям
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from llm.models.gpt import GPT
from llm.tokenizers import BPETokenizer
from shared.configs import (
BASE_GPT_CONFIG, TEST_PROMPTS, GENERATION_CONFIG, PATHS
)
from shared.data import (
print_experiment_info, ensure_directories, ExperimentLogger
)
def load_model_and_tokenizer() -> tuple:
"""
Загружает обученную модель и токенизатор.
Returns:
tuple: (модель, токенизатор, конфигурация)
"""
# Проверяем существование файлов
if not os.path.exists(PATHS["gpt_bpe_model"]):
raise FileNotFoundError(
f"Модель не найдена: {PATHS['gpt_bpe_model']}\n"
f"Сначала обучите модель: uv run python experiments/llm_only/train_gpt_bpe.py"
)
if not os.path.exists(PATHS["bpe_tokenizer"]):
raise FileNotFoundError(
f"Токенизатор не найден: {PATHS['bpe_tokenizer']}"
)
# Загружаем конфигурацию модели
import json
with open(PATHS["gpt_bpe_config"], 'r', encoding='utf-8') as f:
model_config = json.load(f)
# Загружаем токенизатор
print("🔧 Загрузка BPE токенизатора...")
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
# Загружаем модель
print("🔧 Загрузка GPT модели...")
model = GPT(model_config)
model.load_state_dict(torch.load(PATHS["gpt_bpe_model"], map_location='cpu'))
model.eval()
print("✅ Модель загружена")
return model, tokenizer, model_config
def generate_text(
model: GPT,
tokenizer: BPETokenizer,
prompt: str,
config: dict
) -> str:
"""
Генерирует текст на основе промпта.
Args:
model: Обученная GPT модель
tokenizer: BPE токенизатор
prompt: Входной текст
config: Конфигурация генерации
Returns:
str: Сгенерированный текст
"""
print(f"🔤 Промпт: '{prompt}'")
print(f"📊 Параметры: max_tokens={config['max_new_tokens']}, "
f"temp={config['temperature']}, sample={config['do_sample']}")
# Кодируем промпт
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_tensor = torch.tensor([input_ids], dtype=torch.long)
print(f"🎯 Токены промпта: {input_ids}")
print(f"🎯 Токены (текст): {tokenizer.tokenize(prompt)}")
print("🔄 Генерация...")
# Генерируем текст
with torch.no_grad():
generated_ids = model.generate(
x=input_tensor,
max_new_tokens=config["max_new_tokens"],
do_sample=config["do_sample"],
temperature=config["temperature"],
top_k=config["top_k"],
top_p=config["top_p"]
)
# Декодируем результат
generated_text = tokenizer.decode(generated_ids[0].tolist())
return generated_text
def test_different_strategies(model: GPT, tokenizer: BPETokenizer, prompt: str):
"""
Тестирует разные стратегии генерации на одном промпте.
Args:
model: Обученная модель
tokenizer: BPE токенизатор
prompt: Тестовый промпт
"""
print(f"\n🎭 Сравнение стратегий генерации для промпта: '{prompt}'")
print("=" * 60)
strategies = [
{"name": "🎯 Жадный поиск", "do_sample": False, "temperature": 1.0},
{"name": "🎲 Вероятностная (temp=0.7)", "do_sample": True, "temperature": 0.7},
{"name": "🔥 Случайная (temp=1.2)", "do_sample": True, "temperature": 1.2},
{"name": "❄️ Детерминированная (temp=0.3)", "do_sample": True, "temperature": 0.3},
]
for strategy in strategies:
print(f"\n{strategy['name']}:")
try:
config = GENERATION_CONFIG.copy()
config.update({
"do_sample": strategy["do_sample"],
"temperature": strategy["temperature"],
"max_new_tokens": 20
})
generated = generate_text(model, tokenizer, prompt, config)
# Выделяем сгенерированную часть
generated_part = generated[len(prompt):]
print(f" 📤 Промпт: '{prompt}'")
print(f" 🎯 Сгенерировано: '{generated_part}'")
print(f" 📄 Полный текст: '{generated}'")
except Exception as e:
print(f" ❌ Ошибка: {e}")
def analyze_tokenization(tokenizer: BPETokenizer, texts: list):
"""
Анализирует токенизацию различных текстов.
Args:
tokenizer: BPE токенизатор
texts: Список текстов для анализа
"""
print(f"\n🔍 Анализ токенизации BPE:")
print("=" * 50)
for i, text in enumerate(texts):
print(f"\nТекст {i+1}: '{text}'")
# Токенизация
tokens = tokenizer.encode(text, add_special_tokens=False)
token_strings = tokenizer.tokenize(text)
print(f" Токены (ID): {tokens}")
print(f" Токены (текст): {token_strings}")
print(f" Количество токенов: {len(tokens)}")
print(f" Эффективность: {len(text)} символов → {len(tokens)} токенов")
# Декодирование обратно
decoded = tokenizer.decode(tokens)
if text == decoded:
print(f" ✅ Декодирование корректно")
else:
print(f" ⚠️ Расхождения: '{decoded}'")
def interactive_generation(model: GPT, tokenizer: BPETokenizer):
"""
Режим интерактивной генерации.
Args:
model: Обученная модель
tokenizer: BPE токенизатор
"""
print(f"\n💬 Интерактивная генерация (для выхода введите 'exit')")
print("-" * 50)
while True:
try:
user_input = input("\n🔤 Введите промпт: ").strip()
if user_input.lower() in ['exit', 'quit', 'выход']:
break
if not user_input:
continue
# Запрашиваем параметры
try:
max_tokens = int(input("📏 Макс. токенов [50]: ") or "50")
temperature = float(input("🌡️ Температура [0.7]: ") or "0.7")
do_sample_input = input("🎲 Сэмплирование (y/n) [y]: ").lower()
do_sample = do_sample_input != 'n'
except:
max_tokens = 50
temperature = 0.7
do_sample = True
print("⚠️ Использую параметры по умолчанию")
config = GENERATION_CONFIG.copy()
config.update({
"max_new_tokens": max_tokens,
"temperature": temperature,
"do_sample": do_sample
})
generated = generate_text(model, tokenizer, user_input, config)
generated_part = generated[len(user_input):]
print(f"\n🎯 Результат:")
print(f" 📤 Промпт: '{user_input}'")
print(f" 🎯 Сгенерировано: '{generated_part}'")
print(f" 📄 Полный текст: '{generated}'")
except KeyboardInterrupt:
print("\n👋 Завершение работы...")
break
except Exception as e:
print(f"❌ Ошибка: {e}")
def main():
"""Основная функция эксперимента."""
# === Настройка эксперимента ===
experiment_name = "Генерация текста GPT + BPE (только llm)"
experiment_config = {
"model": "GPT с BPE токенизатором",
"стратегия": "автономная генерация",
"вход": "промпты",
"выход": "сгенерированный текст"
}
print_experiment_info(experiment_name, experiment_config)
ensure_directories()
logger = ExperimentLogger(experiment_name)
try:
# Загружаем модель и токенизатор
model, tokenizer, model_config = load_model_and_tokenizer()
# === Анализ токенизации ===
analysis_texts = [
"Искусственный интеллект",
"Нейронные сети",
"Машинное обучение",
]
analyze_tokenization(tokenizer, analysis_texts)
# === Генерация с разными промптами ===
print(f"\n🎯 Генерация текста с разными промптами")
print("=" * 60)
for i, prompt in enumerate(TEST_PROMPTS):
print(f"\n📝 Пример {i+1}/{len(TEST_PROMPTS)}")
print("-" * 40)
try:
generated = generate_text(model, tokenizer, prompt, GENERATION_CONFIG)
# Выделяем сгенерированную часть
generated_part = generated[len(prompt):]
print(f"📤 Промпт: '{prompt}'")
print(f"🎯 Сгенерировано: '{generated_part}'")
print(f"📄 Полный текст: '{generated}'")
print(f"📏 Длина: {len(generated)} символов")
# Логируем успешную генерацию
logger.log_metric(f"generation_length_{i}", len(generated))
except Exception as e:
print(f"❌ Ошибка при генерации: {e}")
continue
# === Сравнение стратегий генерации ===
test_prompt = "Искусственный"
test_different_strategies(model, tokenizer, test_prompt)
# === Интерактивная генерация ===
interactive_generation(model, tokenizer)
# === Сохранение результатов ===
logger.save_logs("checkpoints/llm_only_generation_logs.json")
print(f"\n🎉 Эксперимент генерации завершен успешно!")
except FileNotFoundError as e:
print(f"{e}")
except Exception as e:
print(f"❌ Ошибка в эксперименте: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,176 @@
#!/usr/bin/env python3
"""
Универсальный скрипт для обучения и генерации LLM.
Позволяет выбирать тип модели и действие через аргументы,
а специальные параметры подавать отдельным JSON-конфигом.
"""
import argparse
import json
import os
import sys
import torch
# Добавляем директорию shared среди импортируемых
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from llm.tokenizers import BPETokenizer
from llm.datasets.text_dataset import TextDataset
from llm.training.trainer import Trainer
from shared.data import (
print_experiment_info,
ensure_directories,
load_training_data,
ExperimentLogger,
)
def load_config(config_path):
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
def load_model_class(model_name):
if model_name.lower() == 'gpt':
from llm.models.gpt import GPT
return GPT
elif model_name.lower() == 'gpt2':
from llm.models.gpt import GPT2
return GPT2
elif model_name.lower() == 'llama':
from llm.models.llama import Llama
return Llama
elif model_name.lower() == 'mistral':
from llm.models.mistral import Mistral
return Mistral
elif model_name.lower() == 'mixtral':
from llm.models.mixtral import Mixtral
return Mixtral
elif model_name.lower() == 'gemma':
from llm.models.gemma import Gemma
return Gemma
else:
raise ValueError(f"Модель '{model_name}' не поддерживается.")
def main():
parser = argparse.ArgumentParser(description='Универсальный запуск обучения/генерации LLM.')
parser.add_argument('--model', '-m', type=str, required=True, help='Название модели (gpt, gpt2, llama и т.д.).')
parser.add_argument('--action', '-a', type=str, required=True, choices=['train', 'generate'], help='Действие: train или generate.')
parser.add_argument('--config', '-c', type=str, required=True, help='Путь к JSON-конфигу с параметрами.')
args = parser.parse_args()
config = load_config(args.config)
ModelClass = load_model_class(args.model)
logger = ExperimentLogger(f"{args.action}_{args.model}")
print_experiment_info(f"Эксперимент {args.action} {args.model}", config)
ensure_directories()
# ==== Обучение ====
if args.action == 'train':
train_texts, val_texts = load_training_data()
# --- Токенизатор ---
if os.path.exists(config["bpe_tokenizer"]):
print("📝 Загрузка обученного токенизатора...")
tokenizer = BPETokenizer.load(config["bpe_tokenizer"])
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
else:
print("🔧 Обучение BPE токенизатора...")
tokenizer = BPETokenizer()
tokenizer.train(
texts=train_texts,
vocab_size=config["bpe_vocab_size"],
special_tokens=config["bpe_special_tokens"]
)
os.makedirs(os.path.dirname(config["bpe_tokenizer"]), exist_ok=True)
tokenizer.save(config["bpe_tokenizer"])
print(f"✅ BPE токенизатор обучен и сохранен: {config['bpe_tokenizer']}")
# Тестируем токенизатор (базово)
for test_text in config.get("test_prompts", ["Тест"]):
encoded = tokenizer.encode(test_text)
decoded = tokenizer.decode(encoded)
print(f"[TEST TOK] '{test_text}'{encoded}'{decoded}'")
# --- Модель ---
model_config = config["model_config"]
model_config["vocab_size"] = tokenizer.get_vocab_size()
model = ModelClass(model_config)
# --- Датасет ---
train_dataset = TextDataset(
train_texts,
tokenizer,
block_size=model_config["max_position_embeddings"]
)
print(f" Размер train датасета: {len(train_dataset)} примеров")
# --- Trainer ---
training = config["training"]
trainer = Trainer(
model=model,
train_dataset=train_dataset,
lr=training["learning_rate"],
batch_size=training["batch_size"],
num_epochs=training["num_epochs"],
warmup_steps=training.get("warmup_steps", 0),
)
trainer.train()
# --- Сохранение модели ---
os.makedirs(os.path.dirname(config["model_weights"]), exist_ok=True)
torch.save(model.state_dict(), config["model_weights"])
with open(config["model_config_path"], "w", encoding="utf-8") as f:
json.dump(model_config, f, indent=2, ensure_ascii=False)
print(f"✅ Модель сохранена: {config['model_weights']}")
logger.save_logs(config.get("log_path", "checkpoints/llm_only_training_logs.json"))
# ==== Генерация ====
elif args.action == 'generate':
# --- Загрузка ---
if not os.path.exists(config["model_weights"]):
raise FileNotFoundError(f"Модель не найдена: {config['model_weights']}")
if not os.path.exists(config["bpe_tokenizer"]):
raise FileNotFoundError(f"Токенизатор не найден: {config['bpe_tokenizer']}")
with open(config["model_config_path"], "r", encoding="utf-8") as f:
model_config = json.load(f)
tokenizer = BPETokenizer.load(config["bpe_tokenizer"])
model = ModelClass(model_config)
model.load_state_dict(torch.load(config["model_weights"], map_location="cpu"))
model.eval()
def generate(prompt, gen_cfg):
print(f"Промпт: {prompt}")
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_tensor = torch.tensor([input_ids], dtype=torch.long)
with torch.no_grad():
generated_ids = model.generate(
x=input_tensor,
max_new_tokens=gen_cfg["max_new_tokens"],
do_sample=gen_cfg["do_sample"],
temperature=gen_cfg["temperature"],
top_k=gen_cfg.get("top_k"),
top_p=gen_cfg.get("top_p"),
)
return tokenizer.decode(generated_ids[0].tolist())
prompts = config.get("test_prompts", ["Тестовый промпт"])
gen_cfg = config.get("generation", {
"max_new_tokens": 50,
"temperature": 0.7,
"do_sample": True,
"top_k": None,
"top_p": None
})
for prompt in prompts:
generated = generate(prompt, gen_cfg)
print(f"\n[RESULT] Prompt: '{prompt}'\n---\n{generated}\n{'='*60}")
logger.save_logs(config.get("log_path", "checkpoints/llm_only_generation_logs.json"))
if __name__ == "__main__":
main()

View File

@@ -1,231 +0,0 @@
#!/usr/bin/env python3
"""
Experiment: train_gpt_bpe.py
Description: Обучение GPT модели с собственным BPE токенизатором.
Использует только библиотеку llm без зависимостей от HuggingFace.
"""
import torch
import os
import sys
# Добавляем путь к shared модулям
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from llm.models.gpt import GPT2
from llm.tokenizers import BPETokenizer
from llm.training.dataset import TextDataset
from llm.training.trainer import Trainer
from shared.configs import (
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
TRAINING_CONFIG, PATHS, TEST_PROMPTS
)
from shared.data import (
load_training_data, ensure_directories,
print_experiment_info, ExperimentLogger
)
def train_bpe_tokenizer(texts: list, config: dict) -> BPETokenizer:
"""
Обучает BPE токенизатор на текстах.
Args:
texts: Список текстов для обучения
config: Конфигурация токенизатора
Returns:
BPETokenizer: Обученный токенизатор
"""
print("🔧 Обучение BPE токенизатора...")
tokenizer = BPETokenizer()
tokenizer.train(
texts=texts,
vocab_size=config["vocab_size"],
special_tokens=config["special_tokens"]
)
# Сохраняем токенизатор
os.makedirs(os.path.dirname(PATHS["bpe_tokenizer"]), exist_ok=True)
tokenizer.save(PATHS["bpe_tokenizer"])
print(f"✅ BPE токенизатор обучен и сохранен: {PATHS['bpe_tokenizer']}")
print(f"📊 Размер словаря: {tokenizer.get_vocab_size()}")
return tokenizer
def test_tokenizer(tokenizer: BPETokenizer, texts: list):
"""
Тестирует токенизатор на примерах.
Args:
tokenizer: Обученный токенизатор
texts: Список тестовых текстов
"""
print("\n🧪 Тестирование токенизатора:")
for i, text in enumerate(texts[:3]):
print(f"\nПример {i+1}:")
print(f" Исходный текст: '{text}'")
# Кодирование
tokens = tokenizer.encode(text)
token_strings = tokenizer.tokenize(text)
print(f" Токены (ID): {tokens}")
print(f" Токены (текст): {token_strings}")
print(f" Количество токенов: {len(tokens)}")
# Декодирование
decoded = tokenizer.decode(tokens)
print(f" Декодированный: '{decoded}'")
if text == decoded:
print(" ✅ Кодирование/декодирование корректно")
else:
print(" ⚠️ Небольшие расхождения")
def main():
"""Основная функция эксперимента."""
# === Настройка эксперимента ===
experiment_name = "Обучение GPT2 с BPE токенизатором (только llm)"
experiment_config = {
"model": "GPT2",
"tokenizer": "BPE",
"vocab_size": BPE_CONFIG["vocab_size"],
"training_epochs": TRAINING_CONFIG["num_epochs"],
"batch_size": TRAINING_CONFIG["batch_size"],
"learning_rate": TRAINING_CONFIG["learning_rate"]
}
print_experiment_info(experiment_name, experiment_config)
ensure_directories()
logger = ExperimentLogger(experiment_name)
try:
# === Подготовка данных ===
train_texts, val_texts = load_training_data()
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
# === Обучение токенизатора ===
if os.path.exists(PATHS["bpe_tokenizer"]):
print("📝 Загрузка предварительно обученного токенизатора...")
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
else:
tokenizer = train_bpe_tokenizer(TRAIN_TEXTS, BPE_CONFIG)
# Тестируем токенизатор
test_tokenizer(tokenizer, TEST_PROMPTS[:3])
# === Инициализация модели ===
model_config = BASE_GPT_CONFIG.copy()
model_config["vocab_size"] = tokenizer.get_vocab_size()
print(f"\n🔧 Инициализация GPT2 модели...")
print(f" Размер словаря: {model_config['vocab_size']}")
print(f" Размер эмбеддингов: {model_config['embed_dim']}")
print(f" Количество слоев: {model_config['num_layers']}")
print(f" Количество голов внимания: {model_config['num_heads']}")
model = GPT2(model_config)
# === Подготовка датасета ===
print(f"\n📊 Подготовка датасета...")
train_dataset = TextDataset(
train_texts,
tokenizer,
block_size=model_config["max_position_embeddings"]
)
print(f" Размер train датасета: {len(train_dataset)} примеров")
# === Обучение модели ===
print(f"\n🎯 Начало обучения GPT2 модели...")
trainer = Trainer(
model=model,
train_dataset=train_dataset,
lr=TRAINING_CONFIG["learning_rate"],
batch_size=TRAINING_CONFIG["batch_size"],
num_epochs=TRAINING_CONFIG["num_epochs"],
warmup_steps=TRAINING_CONFIG["warmup_steps"]
)
# Запускаем обучение
trainer.train()
# === Сохранение модели ===
print(f"\n💾 Сохранение модели...")
os.makedirs(os.path.dirname(PATHS["gpt_bpe_model"]), exist_ok=True)
# Сохраняем модель
torch.save(model.state_dict(), PATHS["gpt_bpe_model"])
# Сохраняем конфигурацию
import json
with open(PATHS["gpt_bpe_config"], 'w', encoding='utf-8') as f:
json.dump(model_config, f, indent=2, ensure_ascii=False)
print(f"✅ Модель сохранена:")
print(f" - {PATHS['gpt_bpe_model']}: веса модели")
print(f" - {PATHS['gpt_bpe_config']}: конфигурация модели")
print(f" - {PATHS['bpe_tokenizer']}: токенизатор")
# === Тестирование генерации ===
print(f"\n🧪 Тестирование генерации текста...")
model.eval()
for prompt in TEST_PROMPTS[:3]:
print(f"\n🔤 Промпт: '{prompt}'")
try:
# Кодируем промпт
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_tensor = torch.tensor([input_ids], dtype=torch.long)
# Генерируем текст
with torch.no_grad():
generated_ids = model.generate(
x=input_tensor,
max_new_tokens=20,
do_sample=True,
temperature=0.8
)
# Декодируем результат
generated_text = tokenizer.decode(generated_ids[0].tolist())
generated_part = generated_text[len(prompt):]
print(f"🎯 Сгенерировано: '{generated_part}'")
print(f"📄 Полный текст: '{generated_text}'")
except Exception as e:
print(f"❌ Ошибка генерации: {e}")
# === Сохранение результатов ===
results = {
"experiment": experiment_name,
"model_config": model_config,
"training_config": TRAINING_CONFIG,
"tokenizer_vocab_size": tokenizer.get_vocab_size(),
"final_loss": "см. логи обучения" # В реальном эксперименте можно сохранить final loss
}
logger.save_logs("checkpoints/llm_only_training_logs.json")
print(f"\n🎉 Эксперимент завершен успешно!")
print(f"\n💡 Для использования обученной модели:")
print(f" uv run python experiments/llm_only/generate_gpt_bpe.py")
except Exception as e:
print(f"❌ Ошибка в эксперименте: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -1,231 +0,0 @@
#!/usr/bin/env python3
"""
Experiment: train_gpt_bpe.py
Description: Обучение GPT модели с собственным BPE токенизатором.
Использует только библиотеку llm без зависимостей от HuggingFace.
"""
import torch
import os
import sys
# Добавляем путь к shared модулям
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from llm.models.gpt import GPT
from llm.tokenizers import BPETokenizer
from llm.training.dataset import TextDataset
from llm.training.trainer import Trainer
from shared.configs import (
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
TRAINING_CONFIG, PATHS, TEST_PROMPTS
)
from shared.data import (
load_training_data, ensure_directories,
print_experiment_info, ExperimentLogger
)
def train_bpe_tokenizer(texts: list, config: dict) -> BPETokenizer:
"""
Обучает BPE токенизатор на текстах.
Args:
texts: Список текстов для обучения
config: Конфигурация токенизатора
Returns:
BPETokenizer: Обученный токенизатор
"""
print("🔧 Обучение BPE токенизатора...")
tokenizer = BPETokenizer()
tokenizer.train(
texts=texts,
vocab_size=config["vocab_size"],
special_tokens=config["special_tokens"]
)
# Сохраняем токенизатор
os.makedirs(os.path.dirname(PATHS["bpe_tokenizer"]), exist_ok=True)
tokenizer.save(PATHS["bpe_tokenizer"])
print(f"✅ BPE токенизатор обучен и сохранен: {PATHS['bpe_tokenizer']}")
print(f"📊 Размер словаря: {tokenizer.get_vocab_size()}")
return tokenizer
def test_tokenizer(tokenizer: BPETokenizer, texts: list):
"""
Тестирует токенизатор на примерах.
Args:
tokenizer: Обученный токенизатор
texts: Список тестовых текстов
"""
print("\n🧪 Тестирование токенизатора:")
for i, text in enumerate(texts[:3]):
print(f"\nПример {i+1}:")
print(f" Исходный текст: '{text}'")
# Кодирование
tokens = tokenizer.encode(text)
token_strings = tokenizer.tokenize(text)
print(f" Токены (ID): {tokens}")
print(f" Токены (текст): {token_strings}")
print(f" Количество токенов: {len(tokens)}")
# Декодирование
decoded = tokenizer.decode(tokens)
print(f" Декодированный: '{decoded}'")
if text == decoded:
print(" ✅ Кодирование/декодирование корректно")
else:
print(" ⚠️ Небольшие расхождения")
def main():
"""Основная функция эксперимента."""
# === Настройка эксперимента ===
experiment_name = "Обучение GPT с BPE токенизатором (только llm)"
experiment_config = {
"model": "GPT",
"tokenizer": "BPE",
"vocab_size": BPE_CONFIG["vocab_size"],
"training_epochs": TRAINING_CONFIG["num_epochs"],
"batch_size": TRAINING_CONFIG["batch_size"],
"learning_rate": TRAINING_CONFIG["learning_rate"]
}
print_experiment_info(experiment_name, experiment_config)
ensure_directories()
logger = ExperimentLogger(experiment_name)
try:
# === Подготовка данных ===
train_texts, val_texts = load_training_data()
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
# === Обучение токенизатора ===
if os.path.exists(PATHS["bpe_tokenizer"]):
print("📝 Загрузка предварительно обученного токенизатора...")
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
else:
tokenizer = train_bpe_tokenizer(TRAIN_TEXTS, BPE_CONFIG)
# Тестируем токенизатор
test_tokenizer(tokenizer, TEST_PROMPTS[:3])
# === Инициализация модели ===
model_config = BASE_GPT_CONFIG.copy()
model_config["vocab_size"] = tokenizer.get_vocab_size()
print(f"\n🔧 Инициализация GPT модели...")
print(f" Размер словаря: {model_config['vocab_size']}")
print(f" Размер эмбеддингов: {model_config['embed_dim']}")
print(f" Количество слоев: {model_config['num_layers']}")
print(f" Количество голов внимания: {model_config['num_heads']}")
model = GPT(model_config)
# === Подготовка датасета ===
print(f"\n📊 Подготовка датасета...")
train_dataset = TextDataset(
train_texts,
tokenizer,
block_size=model_config["max_position_embeddings"]
)
print(f" Размер train датасета: {len(train_dataset)} примеров")
# === Обучение модели ===
print(f"\n🎯 Начало обучения GPT модели...")
trainer = Trainer(
model=model,
train_dataset=train_dataset,
lr=TRAINING_CONFIG["learning_rate"],
batch_size=TRAINING_CONFIG["batch_size"],
num_epochs=TRAINING_CONFIG["num_epochs"],
warmup_steps=TRAINING_CONFIG["warmup_steps"]
)
# Запускаем обучение
trainer.train()
# === Сохранение модели ===
print(f"\n💾 Сохранение модели...")
os.makedirs(os.path.dirname(PATHS["gpt_bpe_model"]), exist_ok=True)
# Сохраняем модель
torch.save(model.state_dict(), PATHS["gpt_bpe_model"])
# Сохраняем конфигурацию
import json
with open(PATHS["gpt_bpe_config"], 'w', encoding='utf-8') as f:
json.dump(model_config, f, indent=2, ensure_ascii=False)
print(f"✅ Модель сохранена:")
print(f" - {PATHS['gpt_bpe_model']}: веса модели")
print(f" - {PATHS['gpt_bpe_config']}: конфигурация модели")
print(f" - {PATHS['bpe_tokenizer']}: токенизатор")
# === Тестирование генерации ===
print(f"\n🧪 Тестирование генерации текста...")
model.eval()
for prompt in TEST_PROMPTS[:3]:
print(f"\n🔤 Промпт: '{prompt}'")
try:
# Кодируем промпт
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_tensor = torch.tensor([input_ids], dtype=torch.long)
# Генерируем текст
with torch.no_grad():
generated_ids = model.generate(
x=input_tensor,
max_new_tokens=20,
do_sample=True,
temperature=0.8
)
# Декодируем результат
generated_text = tokenizer.decode(generated_ids[0].tolist())
generated_part = generated_text[len(prompt):]
print(f"🎯 Сгенерировано: '{generated_part}'")
print(f"📄 Полный текст: '{generated_text}'")
except Exception as e:
print(f"❌ Ошибка генерации: {e}")
# === Сохранение результатов ===
results = {
"experiment": experiment_name,
"model_config": model_config,
"training_config": TRAINING_CONFIG,
"tokenizer_vocab_size": tokenizer.get_vocab_size(),
"final_loss": "см. логи обучения" # В реальном эксперименте можно сохранить final loss
}
logger.save_logs("checkpoints/llm_only_training_logs.json")
print(f"\n🎉 Эксперимент завершен успешно!")
print(f"\n💡 Для использования обученной модели:")
print(f" uv run python experiments/llm_only/generate_gpt_bpe.py")
except Exception as e:
print(f"❌ Ошибка в эксперименте: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -1,231 +0,0 @@
#!/usr/bin/env python3
"""
Experiment: train_gpt_bpe.py
Description: Обучение GPT модели с собственным BPE токенизатором.
Использует только библиотеку llm без зависимостей от HuggingFace.
"""
import torch
import os
import sys
# Добавляем путь к shared модулям
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from llm.models.llama import Llama
from llm.tokenizers import BPETokenizer
from llm.training.dataset import TextDataset
from llm.training.trainer import Trainer
from shared.configs import (
TRAIN_TEXTS, BASE_GPT_CONFIG, BPE_CONFIG,
TRAINING_CONFIG, PATHS, TEST_PROMPTS
)
from shared.data import (
load_training_data, ensure_directories,
print_experiment_info, ExperimentLogger
)
def train_bpe_tokenizer(texts: list, config: dict) -> BPETokenizer:
"""
Обучает BPE токенизатор на текстах.
Args:
texts: Список текстов для обучения
config: Конфигурация токенизатора
Returns:
BPETokenizer: Обученный токенизатор
"""
print("🔧 Обучение BPE токенизатора...")
tokenizer = BPETokenizer()
tokenizer.train(
texts=texts,
vocab_size=config["vocab_size"],
special_tokens=config["special_tokens"]
)
# Сохраняем токенизатор
os.makedirs(os.path.dirname(PATHS["bpe_tokenizer"]), exist_ok=True)
tokenizer.save(PATHS["bpe_tokenizer"])
print(f"✅ BPE токенизатор обучен и сохранен: {PATHS['bpe_tokenizer']}")
print(f"📊 Размер словаря: {tokenizer.get_vocab_size()}")
return tokenizer
def test_tokenizer(tokenizer: BPETokenizer, texts: list):
"""
Тестирует токенизатор на примерах.
Args:
tokenizer: Обученный токенизатор
texts: Список тестовых текстов
"""
print("\n🧪 Тестирование токенизатора:")
for i, text in enumerate(texts[:3]):
print(f"\nПример {i+1}:")
print(f" Исходный текст: '{text}'")
# Кодирование
tokens = tokenizer.encode(text)
token_strings = tokenizer.tokenize(text)
print(f" Токены (ID): {tokens}")
print(f" Токены (текст): {token_strings}")
print(f" Количество токенов: {len(tokens)}")
# Декодирование
decoded = tokenizer.decode(tokens)
print(f" Декодированный: '{decoded}'")
if text == decoded:
print(" ✅ Кодирование/декодирование корректно")
else:
print(" ⚠️ Небольшие расхождения")
def main():
"""Основная функция эксперимента."""
# === Настройка эксперимента ===
experiment_name = "Обучение Llama с BPE токенизатором (только llm)"
experiment_config = {
"model": "Llama",
"tokenizer": "BPE",
"vocab_size": BPE_CONFIG["vocab_size"],
"training_epochs": TRAINING_CONFIG["num_epochs"],
"batch_size": TRAINING_CONFIG["batch_size"],
"learning_rate": TRAINING_CONFIG["learning_rate"]
}
print_experiment_info(experiment_name, experiment_config)
ensure_directories()
logger = ExperimentLogger(experiment_name)
try:
# === Подготовка данных ===
train_texts, val_texts = load_training_data()
print(f"📊 Данные: {len(train_texts)} train, {len(val_texts)} validation")
# === Обучение токенизатора ===
if os.path.exists(PATHS["bpe_tokenizer"]):
print("📝 Загрузка предварительно обученного токенизатора...")
tokenizer = BPETokenizer.load(PATHS["bpe_tokenizer"])
print(f"✅ Токенизатор загружен (vocab_size={tokenizer.get_vocab_size()})")
else:
tokenizer = train_bpe_tokenizer(TRAIN_TEXTS, BPE_CONFIG)
# Тестируем токенизатор
test_tokenizer(tokenizer, TEST_PROMPTS[:3])
# === Инициализация модели ===
model_config = BASE_GPT_CONFIG.copy()
model_config["vocab_size"] = tokenizer.get_vocab_size()
print(f"\n🔧 Инициализация Llama модели...")
print(f" Размер словаря: {model_config['vocab_size']}")
print(f" Размер эмбеддингов: {model_config['embed_dim']}")
print(f" Количество слоев: {model_config['num_layers']}")
print(f" Количество голов внимания: {model_config['num_heads']}")
model = Llama(model_config)
# === Подготовка датасета ===
print(f"\n📊 Подготовка датасета...")
train_dataset = TextDataset(
train_texts,
tokenizer,
block_size=model_config["max_position_embeddings"]
)
print(f" Размер train датасета: {len(train_dataset)} примеров")
# === Обучение модели ===
print(f"\n🎯 Начало обучения Llama модели...")
trainer = Trainer(
model=model,
train_dataset=train_dataset,
lr=TRAINING_CONFIG["learning_rate"],
batch_size=TRAINING_CONFIG["batch_size"],
num_epochs=TRAINING_CONFIG["num_epochs"],
warmup_steps=TRAINING_CONFIG["warmup_steps"]
)
# Запускаем обучение
trainer.train()
# === Сохранение модели ===
print(f"\n💾 Сохранение модели...")
os.makedirs(os.path.dirname(PATHS["gpt_bpe_model"]), exist_ok=True)
# Сохраняем модель
torch.save(model.state_dict(), PATHS["gpt_bpe_model"])
# Сохраняем конфигурацию
import json
with open(PATHS["gpt_bpe_config"], 'w', encoding='utf-8') as f:
json.dump(model_config, f, indent=2, ensure_ascii=False)
print(f"✅ Модель сохранена:")
print(f" - {PATHS['gpt_bpe_model']}: веса модели")
print(f" - {PATHS['gpt_bpe_config']}: конфигурация модели")
print(f" - {PATHS['bpe_tokenizer']}: токенизатор")
# === Тестирование генерации ===
print(f"\n🧪 Тестирование генерации текста...")
model.eval()
for prompt in TEST_PROMPTS[:3]:
print(f"\n🔤 Промпт: '{prompt}'")
try:
# Кодируем промпт
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_tensor = torch.tensor([input_ids], dtype=torch.long)
# Генерируем текст
with torch.no_grad():
generated_ids = model.generate(
x=input_tensor,
max_new_tokens=20,
do_sample=True,
temperature=0.8
)
# Декодируем результат
generated_text = tokenizer.decode(generated_ids[0].tolist())
generated_part = generated_text[len(prompt):]
print(f"🎯 Сгенерировано: '{generated_part}'")
print(f"📄 Полный текст: '{generated_text}'")
except Exception as e:
print(f"❌ Ошибка генерации: {e}")
# === Сохранение результатов ===
results = {
"experiment": experiment_name,
"model_config": model_config,
"training_config": TRAINING_CONFIG,
"tokenizer_vocab_size": tokenizer.get_vocab_size(),
"final_loss": "см. логи обучения" # В реальном эксперименте можно сохранить final loss
}
logger.save_logs("checkpoints/llm_only_training_logs.json")
print(f"\n🎉 Эксперимент завершен успешно!")
print(f"\n💡 Для использования обученной модели:")
print(f" uv run python experiments/llm_only/generate_gpt_bpe.py")
except Exception as e:
print(f"❌ Ошибка в эксперименте: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
main()

View File

@@ -30,7 +30,7 @@ BASE_GPT_CONFIG = {
"num_heads": 4,
"num_layers": 4,
"max_position_embeddings": 128,
"dropout": 0.1
"dropout": 0.1,
}
# Конфигурация для маленькой модели (быстрое тестирование)
@@ -40,7 +40,7 @@ SMALL_GPT_CONFIG = {
"num_heads": 2,
"num_layers": 2,
"max_position_embeddings": 64,
"dropout": 0.1
"dropout": 0.1,
}
# Конфигурация для большой модели (качественное обучение)
@@ -50,13 +50,13 @@ LARGE_GPT_CONFIG = {
"num_heads": 8,
"num_layers": 6,
"max_position_embeddings": 256,
"dropout": 0.1
"dropout": 0.1,
}
# === Конфигурации токенизатора ===
BPE_CONFIG = {
"vocab_size": 1000,
"special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"]
"special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
}
# === Конфигурации обучения ===
@@ -65,7 +65,7 @@ TRAINING_CONFIG = {
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50,
"gradient_clip": 1.0
"gradient_clip": 1.0,
}
# === Конфигурации генерации ===
@@ -74,7 +74,7 @@ GENERATION_CONFIG = {
"temperature": 0.7,
"do_sample": True,
"top_k": None,
"top_p": None
"top_p": None,
}
# === Пути для сохранения ===
@@ -84,7 +84,7 @@ PATHS = {
"gpt_bpe_config": "checkpoints/gpt-bpe/config.json",
"hf_tokenizer": "checkpoints/hf-bpe-tokenizer",
"hf_model": "checkpoints/hf-trained",
"hf_proxy_model": "checkpoints/hf-trained-proxy"
"hf_proxy_model": "checkpoints/hf-trained-proxy",
}
# === Тестовые промпты ===

View File

@@ -32,7 +32,7 @@ def ensure_directories():
"checkpoints/hf-bpe-tokenizer",
"checkpoints/hf-trained",
"checkpoints/hf-trained-proxy",
"logs"
"logs",
]
for directory in directories:
@@ -52,15 +52,16 @@ def get_model_paths(experiment_type: str = "llm_only") -> dict:
base_paths = PATHS.copy()
if experiment_type == "hf_integration":
base_paths.update({
"model": base_paths["hf_model"],
"tokenizer": base_paths["hf_tokenizer"]
})
base_paths.update(
{"model": base_paths["hf_model"], "tokenizer": base_paths["hf_tokenizer"]}
)
else: # llm_only
base_paths.update({
"model": base_paths["gpt_bpe_model"],
"tokenizer": base_paths["bpe_tokenizer"]
})
base_paths.update(
{
"model": base_paths["gpt_bpe_model"],
"tokenizer": base_paths["bpe_tokenizer"],
}
)
return base_paths
@@ -92,7 +93,7 @@ def save_experiment_results(results: dict, filepath: str):
"""
import json
with open(filepath, 'w', encoding='utf-8') as f:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(results, f, ensure_ascii=False, indent=2)
print(f"✅ Результаты эксперимента сохранены: {filepath}")
@@ -113,7 +114,7 @@ def load_experiment_results(filepath: str) -> dict:
if not os.path.exists(filepath):
return {}
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, "r", encoding="utf-8") as f:
return json.load(f)
@@ -151,12 +152,9 @@ class ExperimentLogger:
"""Сохраняет логи эксперимента."""
import json
logs = {
"experiment_name": self.experiment_name,
"metrics": self.metrics
}
logs = {"experiment_name": self.experiment_name, "metrics": self.metrics}
with open(filepath, 'w', encoding='utf-8') as f:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(logs, f, ensure_ascii=False, indent=2)
print(f"✅ Логи эксперимента сохранены: {filepath}")

View File

@@ -27,16 +27,13 @@ __all__ = [
# Основные классы адаптера
"HFAdapter",
"HFGPTAdapter",
# Конфигурации
"HFAdapterConfig",
"HFPretrainedConfig",
# Адаптеры токенизаторов
"HFTokenizerAdapter",
"create_hf_tokenizer",
"convert_to_hf_format",
# Утилиты
"HFUtils",
"TokenizerWrapper",

View File

@@ -11,7 +11,7 @@ from transformers import (
GPT2Config,
GenerationConfig,
LogitsProcessorList,
StoppingCriteriaList
StoppingCriteriaList,
)
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
@@ -24,6 +24,7 @@ class HFGPTAdapter(PreTrainedModel):
Адаптер для модели GPT из библиотеки llm.
Позволяет использовать кастомные GPT модели с HuggingFace Transformers.
"""
config_class = HFPretrainedConfig
def __init__(self, config: HFPretrainedConfig, llm_model: Optional[GPT] = None):
@@ -46,7 +47,7 @@ class HFGPTAdapter(PreTrainedModel):
self.llm_model = llm_model
# Устанавливаем веса если они есть в конфигурации
if hasattr(config, 'state_dict') and config.state_dict is not None:
if hasattr(config, "state_dict") and config.state_dict is not None:
self.llm_model.load_state_dict(config.state_dict)
def _hf_to_llm_config(self, hf_config: HFPretrainedConfig) -> dict:
@@ -78,7 +79,7 @@ class HFGPTAdapter(PreTrainedModel):
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs
**kwargs,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
"""
Прямой проход модели.
@@ -96,7 +97,9 @@ class HFGPTAdapter(PreTrainedModel):
Returns:
CausalLMOutputWithCrossAttentions или кортеж
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# Основной forward pass
outputs = self.llm_model(input_ids)
@@ -114,8 +117,7 @@ class HFGPTAdapter(PreTrainedModel):
# Вычисляем cross-entropy loss
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1)
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
@@ -134,10 +136,7 @@ class HFGPTAdapter(PreTrainedModel):
)
def prepare_inputs_for_generation(
self,
input_ids: torch.Tensor,
past_key_values: Optional[Tuple] = None,
**kwargs
self, input_ids: torch.Tensor, past_key_values: Optional[Tuple] = None, **kwargs
) -> dict:
"""
Подготавливает входные данные для генерации.
@@ -163,7 +162,7 @@ class HFGPTAdapter(PreTrainedModel):
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
**kwargs
**kwargs,
) -> torch.Tensor:
"""
Генерация текста с поддержкой HuggingFace интерфейса.
@@ -179,8 +178,8 @@ class HFGPTAdapter(PreTrainedModel):
torch.Tensor: Сгенерированные токены
"""
# Извлекаем обязательные параметры из kwargs или используем значения по умолчанию
max_new_tokens = kwargs.pop('max_new_tokens', 50)
do_sample = kwargs.pop('do_sample', True)
max_new_tokens = kwargs.pop("max_new_tokens", 50)
do_sample = kwargs.pop("do_sample", True)
# Используем встроенную генерацию llm модели
return self.llm_model.generate(
@@ -188,7 +187,7 @@ class HFGPTAdapter(PreTrainedModel):
max_new_tokens=max_new_tokens,
do_sample=do_sample,
attention_mask=attention_mask,
**kwargs
**kwargs,
)
@@ -199,8 +198,7 @@ class HFAdapter:
@staticmethod
def from_llm_model(
llm_model: GPT,
hf_config: Optional[HFAdapterConfig] = None
llm_model: GPT, hf_config: Optional[HFAdapterConfig] = None
) -> HFGPTAdapter:
"""
Создает адаптер из существующей llm модели.
@@ -223,8 +221,7 @@ class HFAdapter:
@staticmethod
def from_pretrained(
model_path: str,
hf_config: Optional[HFAdapterConfig] = None
model_path: str, hf_config: Optional[HFAdapterConfig] = None
) -> HFGPTAdapter:
"""
Загружает модель из чекпоинта и создает адаптер.
@@ -237,14 +234,18 @@ class HFAdapter:
HFGPTAdapter: Адаптированная модель
"""
# Загружаем состояние модели
state_dict = torch.load(model_path, map_location='cpu')
state_dict = torch.load(model_path, map_location="cpu")
# Определяем конфигурацию из состояния модели или используем переданную
if hf_config is None:
# Пытаемся определить конфигурацию из состояния модели
# Это упрощенный подход - в реальности нужно сохранять конфигурацию отдельно
vocab_size = state_dict.get('_token_embeddings._embedding.weight', torch.zeros(50257, 768)).shape[0]
embed_dim = state_dict.get('_token_embeddings._embedding.weight', torch.zeros(50257, 768)).shape[1]
vocab_size = state_dict.get(
"_token_embeddings._embedding.weight", torch.zeros(50257, 768)
).shape[0]
embed_dim = state_dict.get(
"_token_embeddings._embedding.weight", torch.zeros(50257, 768)
).shape[1]
hf_config = HFAdapterConfig(
vocab_size=vocab_size,
@@ -270,11 +271,7 @@ class HFAdapter:
return HFGPTAdapter(pretrained_config, llm_model)
@staticmethod
def save_pretrained(
model: HFGPTAdapter,
save_directory: str,
**kwargs
):
def save_pretrained(model: HFGPTAdapter, save_directory: str, **kwargs):
"""
Сохраняет адаптированную модель в формате HuggingFace.
@@ -291,7 +288,7 @@ class HFAdapter:
# Сохраняем конфигурацию
config_path = os.path.join(save_directory, "config.json")
with open(config_path, 'w', encoding='utf-8') as f:
with open(config_path, "w", encoding="utf-8") as f:
json.dump(model.config.to_dict(), f, indent=2, ensure_ascii=False)
# Сохраняем веса модели
@@ -299,5 +296,5 @@ class HFAdapter:
torch.save(model.llm_model.state_dict(), model_path)
# Сохраняем токенизатор если передан
if hasattr(kwargs, 'tokenizer') and kwargs['tokenizer'] is not None:
kwargs['tokenizer'].save_pretrained(save_directory)
if hasattr(kwargs, "tokenizer") and kwargs["tokenizer"] is not None:
kwargs["tokenizer"].save_pretrained(save_directory)

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from typing import Dict, Any, Optional
from transformers import PretrainedConfig
@dataclass
class HFAdapterConfig:
"""
@@ -28,6 +29,7 @@ class HFAdapterConfig:
eos_token_id: ID токена конца строки
bos_token_id: ID токена начала строки
"""
model_type: str = "gpt"
vocab_size: int = 50257
hidden_size: int = 768
@@ -52,8 +54,9 @@ class HFAdapterConfig:
def to_dict(self) -> Dict[str, Any]:
"""Преобразует конфигурацию в словарь."""
return {
k: v for k, v in self.__dict__.items()
if not k.startswith('_') and not callable(v)
k: v
for k, v in self.__dict__.items()
if not k.startswith("_") and not callable(v)
}
@classmethod
@@ -74,7 +77,7 @@ class HFAdapterConfig:
"num_heads": "num_attention_heads",
"max_position_embeddings": "max_position_embeddings",
"dropout": "hidden_dropout_prob",
"vocab_size": "vocab_size"
"vocab_size": "vocab_size",
}
hf_config_dict = {}
@@ -94,6 +97,7 @@ class HFPretrainedConfig(PretrainedConfig):
Конфигурация для предобученных моделей HuggingFace.
Наследуется от PretrainedConfig для полной совместимости.
"""
model_type = "gpt"
def __init__(
@@ -112,13 +116,13 @@ class HFPretrainedConfig(PretrainedConfig):
pad_token_id=50256,
eos_token_id=50256,
bos_token_id=50256,
**kwargs
**kwargs,
):
super().__init__(
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
bos_token_id=bos_token_id,
**kwargs
**kwargs,
)
self.vocab_size = vocab_size

View File

@@ -27,16 +27,16 @@ class HFTokenizerAdapter:
self.vocab_size = llm_tokenizer.get_vocab_size()
# Устанавливаем специальные токены
self.pad_token = getattr(llm_tokenizer, 'pad_token', '<pad>')
self.unk_token = getattr(llm_tokenizer, 'unk_token', '<unk>')
self.bos_token = getattr(llm_tokenizer, 'bos_token', '<bos>')
self.eos_token = getattr(llm_tokenizer, 'eos_token', '<eos>')
self.pad_token = getattr(llm_tokenizer, "pad_token", "<pad>")
self.unk_token = getattr(llm_tokenizer, "unk_token", "<unk>")
self.bos_token = getattr(llm_tokenizer, "bos_token", "<bos>")
self.eos_token = getattr(llm_tokenizer, "eos_token", "<eos>")
# Сохраняем ID специальных токенов
self.pad_token_id = getattr(llm_tokenizer, 'pad_token_id', 0)
self.unk_token_id = getattr(llm_tokenizer, 'unk_token_id', 1)
self.bos_token_id = getattr(llm_tokenizer, 'bos_token_id', 2)
self.eos_token_id = getattr(llm_tokenizer, 'eos_token_id', 3)
self.pad_token_id = getattr(llm_tokenizer, "pad_token_id", 0)
self.unk_token_id = getattr(llm_tokenizer, "unk_token_id", 1)
self.bos_token_id = getattr(llm_tokenizer, "bos_token_id", 2)
self.eos_token_id = getattr(llm_tokenizer, "eos_token_id", 3)
def __call__(self, text: str, **kwargs):
"""
@@ -49,30 +49,27 @@ class HFTokenizerAdapter:
Returns:
dict: Словарь с токенами
"""
return_tensors = kwargs.get('return_tensors', None)
padding = kwargs.get('padding', False)
truncation = kwargs.get('truncation', False)
max_length = kwargs.get('max_length', None)
add_special_tokens = kwargs.get('add_special_tokens', True)
return_tensors = kwargs.get("return_tensors", None)
padding = kwargs.get("padding", False)
truncation = kwargs.get("truncation", False)
max_length = kwargs.get("max_length", None)
add_special_tokens = kwargs.get("add_special_tokens", True)
# Кодируем текст
#input_ids = self.llm_tokenizer.encode(
# input_ids = self.llm_tokenizer.encode(
# text,
# add_special_tokens=add_special_tokens
#)
# )
if isinstance(text, str):
input_ids = self.llm_tokenizer.encode(
text,
add_special_tokens=add_special_tokens
text, add_special_tokens=add_special_tokens
)
input_ids = [input_ids] # <-- оборачиваем в batch
else:
# Список строк, батч-режим!
input_ids = [
self.llm_tokenizer.encode(
t,
add_special_tokens=add_special_tokens
) for t in text
self.llm_tokenizer.encode(t, add_special_tokens=add_special_tokens)
for t in text
]
# Применяем truncation
@@ -86,6 +83,7 @@ class HFTokenizerAdapter:
# Конвертируем в тензоры если нужно
if return_tensors == "pt":
import torch
input_ids = torch.tensor([input_ids])
return {"input_ids": input_ids}
@@ -99,7 +97,7 @@ class HFTokenizerAdapter:
truncation: bool = False,
max_length: Optional[int] = None,
return_tensors: Optional[str] = None,
**kwargs
**kwargs,
) -> Union[List[int], List[List[int]]]:
"""
Кодирует текст в последовательность токенов.
@@ -118,16 +116,12 @@ class HFTokenizerAdapter:
"""
# Кодируем основной текст
token_ids = self.llm_tokenizer.encode(
text,
add_special_tokens=add_special_tokens
text, add_special_tokens=add_special_tokens
)
# Обрабатываем text_pair если есть
if text_pair is not None:
pair_ids = self.llm_tokenizer.encode(
text_pair,
add_special_tokens=False
)
pair_ids = self.llm_tokenizer.encode(text_pair, add_special_tokens=False)
token_ids.extend(pair_ids)
# Применяем truncation
@@ -141,9 +135,11 @@ class HFTokenizerAdapter:
# Конвертируем в тензоры если нужно
if return_tensors == "pt":
import torch
return torch.tensor([token_ids])
elif return_tensors == "np":
import numpy as np
return np.array([token_ids])
return token_ids
@@ -152,7 +148,7 @@ class HFTokenizerAdapter:
self,
token_ids: Union[int, List[int], List[List[int]]],
skip_special_tokens: bool = True,
**kwargs
**kwargs,
) -> str:
"""
Декодирует последовательность токенов в текст.
@@ -167,13 +163,22 @@ class HFTokenizerAdapter:
# Обрабатываем разные форматы входных данных
if isinstance(token_ids, int):
token_ids = [token_ids]
elif isinstance(token_ids, list) and len(token_ids) > 0 and isinstance(token_ids[0], list):
elif (
isinstance(token_ids, list)
and len(token_ids) > 0
and isinstance(token_ids[0], list)
):
# Список списков - берем первый элемент
token_ids = token_ids[0]
# Фильтруем специальные токены если нужно
if skip_special_tokens:
special_ids = {self.pad_token_id, self.unk_token_id, self.bos_token_id, self.eos_token_id}
special_ids = {
self.pad_token_id,
self.unk_token_id,
self.bos_token_id,
self.eos_token_id,
}
token_ids = [tid for tid in token_ids if tid not in special_ids]
return self.llm_tokenizer.decode(token_ids)
@@ -224,8 +229,12 @@ class HFTokenizerAdapter:
# Обрабатываем разные типы данных
if isinstance(input_ids, int):
seq_len = 1
elif hasattr(input_ids, 'shape'):
seq_len = input_ids.shape[-1] if len(input_ids.shape) > 1 else len(input_ids)
elif hasattr(input_ids, "shape"):
seq_len = (
input_ids.shape[-1]
if len(input_ids.shape) > 1
else len(input_ids)
)
else:
seq_len = len(input_ids)
max_len = max(max_len, seq_len)
@@ -240,8 +249,12 @@ class HFTokenizerAdapter:
# Получаем текущую длину
if isinstance(input_ids, int):
current_len = 1
elif hasattr(input_ids, 'shape'):
current_len = input_ids.shape[-1] if len(input_ids.shape) > 1 else len(input_ids)
elif hasattr(input_ids, "shape"):
current_len = (
input_ids.shape[-1]
if len(input_ids.shape) > 1
else len(input_ids)
)
else:
current_len = len(input_ids)
@@ -251,20 +264,27 @@ class HFTokenizerAdapter:
# Обрабатываем разные типы данных
if isinstance(input_ids, int):
item["input_ids"] = [input_ids] + [self.pad_token_id] * padding_length
elif hasattr(input_ids, 'shape'):
item["input_ids"] = [input_ids] + [
self.pad_token_id
] * padding_length
elif hasattr(input_ids, "shape"):
import torch
padding_tensor = torch.full((padding_length,), self.pad_token_id, dtype=input_ids.dtype)
padding_tensor = torch.full(
(padding_length,), self.pad_token_id, dtype=input_ids.dtype
)
item["input_ids"] = torch.cat([input_ids, padding_tensor])
else:
item["input_ids"] = input_ids + [self.pad_token_id] * padding_length
item["input_ids"] = (
input_ids + [self.pad_token_id] * padding_length
)
# Добавляем attention_mask если требуется
if "attention_mask" in item:
mask = item["attention_mask"]
if isinstance(mask, int):
item["attention_mask"] = [mask] + [0] * padding_length
elif hasattr(mask, 'shape'):
elif hasattr(mask, "shape"):
padding_mask = torch.zeros(padding_length, dtype=mask.dtype)
item["attention_mask"] = torch.cat([mask, padding_mask])
else:
@@ -272,16 +292,21 @@ class HFTokenizerAdapter:
elif return_attention_mask:
if isinstance(input_ids, int):
item["attention_mask"] = [1] + [0] * padding_length
elif hasattr(input_ids, 'shape'):
elif hasattr(input_ids, "shape"):
attention_mask = torch.ones(current_len, dtype=torch.long)
padding_mask = torch.zeros(padding_length, dtype=torch.long)
item["attention_mask"] = torch.cat([attention_mask, padding_mask])
item["attention_mask"] = torch.cat(
[attention_mask, padding_mask]
)
else:
item["attention_mask"] = [1] * current_len + [0] * padding_length
item["attention_mask"] = [1] * current_len + [
0
] * padding_length
# Конвертируем в тензоры если требуется
if return_tensors == "pt":
import torch
for key in list(encoded_inputs[0].keys()):
if isinstance(encoded_inputs[0][key], list):
for i in range(len(encoded_inputs)):
@@ -326,12 +351,12 @@ class HFTokenizerAdapter:
}
config_path = os.path.join(save_directory, "tokenizer_config.json")
with open(config_path, 'w', encoding='utf-8') as f:
with open(config_path, "w", encoding="utf-8") as f:
json.dump(tokenizer_config, f, ensure_ascii=False, indent=2)
# Сохраняем словарь
vocab_path = os.path.join(save_directory, "vocab.json")
with open(vocab_path, 'w', encoding='utf-8') as f:
with open(vocab_path, "w", encoding="utf-8") as f:
json.dump(self._vocab, f, ensure_ascii=False, indent=2)
print(f"✅ Токенизатор сохранен в {save_directory}")
@@ -353,7 +378,9 @@ class HFTokenizerAdapter:
# Проверяем, является ли путь директорией с файлами токенизатора
if os.path.isdir(pretrained_model_name_or_path):
# Загружаем из директории
config_path = os.path.join(pretrained_model_name_or_path, "tokenizer_config.json")
config_path = os.path.join(
pretrained_model_name_or_path, "tokenizer_config.json"
)
vocab_path = os.path.join(pretrained_model_name_or_path, "vocab.json")
if not os.path.exists(config_path) or not os.path.exists(vocab_path):
@@ -362,7 +389,7 @@ class HFTokenizerAdapter:
)
# Загружаем конфигурацию
with open(config_path, 'r', encoding='utf-8') as f:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
# Определяем тип токенизатора llm
@@ -373,7 +400,7 @@ class HFTokenizerAdapter:
llm_tokenizer = BPETokenizer()
# Загружаем словарь
with open(vocab_path, 'r', encoding='utf-8') as f:
with open(vocab_path, "r", encoding="utf-8") as f:
vocab = json.load(f)
llm_tokenizer.vocab = vocab
@@ -393,7 +420,9 @@ class HFTokenizerAdapter:
return cls(llm_tokenizer, **kwargs)
else:
raise ValueError(f"Неподдерживаемый тип токенизатора: {llm_tokenizer_type}")
raise ValueError(
f"Неподдерживаемый тип токенизатора: {llm_tokenizer_type}"
)
else:
# Пытаемся загрузить как файл llm токенизатора

View File

@@ -31,9 +31,7 @@ class HFUtils:
@staticmethod
def convert_to_hf_format(
llm_model,
tokenizer = None,
model_name: str = "custom-gpt"
llm_model, tokenizer=None, model_name: str = "custom-gpt"
) -> tuple:
"""
Конвертирует llm модель в формат HuggingFace.
@@ -52,13 +50,17 @@ class HFUtils:
# Если токенизатор не передан, создаем стандартный
if tokenizer is None:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Устанавливаем специальные токены
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
elif hasattr(tokenizer, '__class__') and 'BPETokenizer' in str(tokenizer.__class__):
elif hasattr(tokenizer, "__class__") and "BPETokenizer" in str(
tokenizer.__class__
):
# Если передан наш кастомный токенизатор, создаем адаптер
from .hf_tokenizer import create_hf_tokenizer
tokenizer = create_hf_tokenizer(tokenizer)
return hf_model, tokenizer
@@ -70,7 +72,7 @@ class HFUtils:
repo_name: str,
organization: Optional[str] = None,
private: bool = False,
**kwargs
**kwargs,
):
"""
Загружает модель в HuggingFace Hub.
@@ -116,7 +118,7 @@ class HFUtils:
api.upload_folder(
folder_path=tmp_dir,
repo_id=repo_id,
commit_message="Initial commit with custom GPT model"
commit_message="Initial commit with custom GPT model",
)
print(f"✅ Модель успешно загружена в HuggingFace Hub: {repo_id}")
@@ -128,10 +130,7 @@ class HFUtils:
)
@staticmethod
def load_from_hub(
repo_id: str,
**kwargs
) -> tuple:
def load_from_hub(repo_id: str, **kwargs) -> tuple:
"""
Загружает модель из HuggingFace Hub.
@@ -162,17 +161,14 @@ class HFUtils:
# Загружаем модель через адаптер
model = HFAdapter.from_pretrained(
f"{repo_id}/pytorch_model.bin",
HFAdapterConfig.from_llm_config(llm_config)
f"{repo_id}/pytorch_model.bin", HFAdapterConfig.from_llm_config(llm_config)
)
return model, tokenizer
@staticmethod
def compare_with_hf_model(
llm_model,
hf_model_name: str = "gpt2",
test_input: str = "Hello world"
llm_model, hf_model_name: str = "gpt2", test_input: str = "Hello world"
) -> Dict[str, Any]:
"""
Сравнивает llm модель с эталонной моделью из HuggingFace.
@@ -197,7 +193,7 @@ class HFUtils:
# Получаем логиты от обеих моделей
with torch.no_grad():
hf_logits = hf_model(**inputs).logits
llm_logits = llm_model(inputs['input_ids'])
llm_logits = llm_model(inputs["input_ids"])
# Сравниваем результаты
hf_probs = torch.softmax(hf_logits[0, -1], dim=-1)
@@ -205,15 +201,11 @@ class HFUtils:
# Вычисляем метрики
kl_divergence = torch.nn.functional.kl_div(
torch.log(llm_probs + 1e-8),
hf_probs,
reduction='batchmean'
torch.log(llm_probs + 1e-8), hf_probs, reduction="batchmean"
)
cosine_similarity = torch.nn.functional.cosine_similarity(
hf_logits.flatten(),
llm_logits.flatten(),
dim=0
hf_logits.flatten(), llm_logits.flatten(), dim=0
)
return {
@@ -244,11 +236,7 @@ class TokenizerWrapper:
Dict: Токенизированные данные
"""
return self.tokenizer(
texts,
padding=True,
truncation=True,
return_tensors="pt",
**kwargs
texts, padding=True, truncation=True, return_tensors="pt", **kwargs
)
def decode_batch(self, token_ids: torch.Tensor, **kwargs) -> List[str]:
@@ -268,9 +256,7 @@ class TokenizerWrapper:
texts = []
for i in range(token_ids.size(0)):
text = self.tokenizer.decode(
token_ids[i],
skip_special_tokens=True,
**kwargs
token_ids[i], skip_special_tokens=True, **kwargs
)
texts.append(text)
@@ -290,12 +276,7 @@ class TokenizerWrapper:
}
def create_hf_pipeline(
llm_model,
tokenizer=None,
device: str = "auto",
**kwargs
):
def create_hf_pipeline(llm_model, tokenizer=None, device: str = "auto", **kwargs):
"""
Создает HuggingFace pipeline из llm модели.
@@ -315,11 +296,7 @@ def create_hf_pipeline(
# Создаем pipeline
pipe = pipeline(
"text-generation",
model=hf_model,
tokenizer=tokenizer,
device=device,
**kwargs
"text-generation", model=hf_model, tokenizer=tokenizer, device=device, **kwargs
)
return pipe

View File

@@ -27,14 +27,19 @@ llm/
│ │ ├── gpt.py # Базовая GPT
│ │ ├── gpt2.py # GPT-2 реализация
│ │ └── __init__.py
── llama/ # LLaMA архитектура
├── llama.py # LLaMA реализация
── llama/ # LLaMA архитектура
├── llama.py # LLaMA реализация
│ │ └── __init__.py
│ └── mistral/ # Mistral архитектура
│ ├── mistral.py # Mistral реализация
│ └── __init__.py
├── tokenizers/ # Токенизаторы
│ ├── base_tokenizer.py # Базовый интерфейс
│ └── bpe_tokenizer.py # BPE токенизатор
├── datasets/ # Работа с датасетами
│ ├── text_dataset.py # Стандартный датасет
│ └── streaming_text_dataset.py # Стриминговый датасет
└── training/ # Утилиты обучения
├── dataset.py # Датасеты
├── trainer.py # Тренировочный цикл
├── optimizer.py # Оптимизаторы
└── scheduler.py # Планировщики обучения
@@ -176,12 +181,11 @@ generated = model.generate(input_ids, max_length=100)
- ✅ Базовая архитектура трансформер-декодера
### GPT-2 Особенности
- ✅ Улучшенная версия оригинальной GPT
- ✅ Layer Normalization (перед вниманием и FFN)
- ✅ GELU активация
- ✅ Learned positional embeddings
- ✅ Кэширование для эффективной генерации
-Оптимизированные веса инициализации
- ✅ Кэширование KV для быстрой генерации
-Улучшенная инициализация слоёв
### LLaMA Особенности
- ✅ Rotary Positional Embeddings (RoPE)
@@ -190,6 +194,21 @@ generated = model.generate(input_ids, max_length=100)
- ✅ Оптимизированная структура декодера
- ✅ Эффективное кэширование KV-памяти
### Mistral Особенности
- ✅ Sliding Window Attention (оконное внимание)
- ✅ Grouped Query Attention (GQA)
- ✅ RoPE
- ✅ RMSNorm
- ✅ Разделённая архитектура на блоки с эффективным управлением памятью
- ✅ Совместимость с HuggingFace через hf-proxy
## 🤝 Интеграция с HuggingFace и BPE
- Встроенная поддержка собственных BPE токенизаторов и экспериментальная поддержка токенизаторов через HuggingFace (см. hf-proxy).
- hf-proxy — экспериментальный модуль! Совместимость с будущими версиями Transformers не гарантируется; API может меняться.
- Допускается загрузка/конвертация моделей в формат HF для использования экосистемы Transformers.
- Для запуска моделей с токенизаторами HF используйте `hf-proxy` и соответствующие эксперименты из `experiments/hf_integration/`.
## 🧪 Тестирование
Запуск всех тестов:
@@ -198,7 +217,7 @@ cd llm
python -m pytest tests/ -v
```
**Статус тестов:** ✅ 101 тест пройден
**Статус тестов:** ✅ 101+ тест, охвачены все основные компоненты (ядро, ядро-токенизация, архитектуры, обучение)
## 📚 Научные концепции

View File

@@ -19,6 +19,7 @@ from abc import ABC, abstractmethod
from typing import Optional, Tuple
import torch
class BaseModel(nn.Module, ABC):
"""
Абстрактный класс — стандарт для всех архитектур LLM.
@@ -32,6 +33,7 @@ class BaseModel(nn.Module, ABC):
Attributes:
config (dict): Конфиг модели
"""
def __init__(self, config: dict):
"""
Инициализация модели.
@@ -43,7 +45,9 @@ class BaseModel(nn.Module, ABC):
self.config = config
@abstractmethod
def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(
self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Прямой проход — получение логитов для входных токенов.

View File

@@ -6,36 +6,51 @@ from .feed_forward import FeedForward
from .multi_head_attention import MultiHeadAttention
from .rope import RoPE
class CachedDecoder(nn.Module):
"""
Универсальный декодерный блок для современных LLM (GPT, LLaMA, др.), поддерживает кэширование key-value для эффективной генерации.
CachedDecoder — Transformer-декодер с key/value-кэшированием (реализация накладывающегося masked multi-head attention).
Научная идея:
Автопагрессивная авторегрессия в трансформерах требует быстрого доступа к ранее вычисленным self-attention ключам/значениям — этот класс позволяет прозрачно кэшировать такие состояния для быстрой инференс-генерации.
Назначение:
-----------
Позволяет быстро и эффективно реализовывать autoregressive генерацию текста в стиле GPT-2/3/4:
- На шаге генерации используются только нужные токены, “прошлые” key/value значения не пересчитываются, а подаются из кэша.
- Позволяет значительно ускорять inferece (особенно на длинных последовательностях).
- Вдохновлено реализациями в HuggingFace transformers, GPT-2/3 и других LLM.
Алгоритм:
- Input -> LayerNorm -> Многоголовое внимание с кэшем (может быть RoPE)
- Суммируем residual
- LayerNorm -> FeedForward (любой, например SwiGLU) -> Residual
- Возвращается кортеж (output, kvcache)
Архитектурные особенности:
--------------------------
- Использует классическую multi-head attention (с causal mask — запрещает видеть “будущее”).
- Предусматривает передачу и накопление KV-cache для каждого слоя (hidden state attention).
- Поддерживает передачу внимания через стек attention-блоков.
- Применяется layernorm и feed-forward block (GELU).
Args:
feed_forward_layer (nn.Module): FeedForward или SwiGLU слой
num_heads (int): Количество голов внимания
emb_size (int): Размерность эмбеддингов
head_size (int): Размерность головы внимания
max_seq_len (int): Максимальная длина
norm_layer (тип nn.Module): Normalization слой (LayerNorm или RMSNorm)
dropout (float): Dropout
rope (RoPE|None): Экземпляр RoPE (для LLaMA)
Параметры конструктора:
-----------------------
num_heads : int — число attention heads
emb_size : int — embedding размерность
head_size : intразмер каждой attention head (обычно emb_size // num_heads)
feed_forward_layer : nn.Module — feedforward блок (mLP), может быть любым PyTorch-слоем
max_seq_len : int — максимально допустимая длина последовательности
dropout : float — dropout на attention/ffn
Пример использования:
---------------------
>>> from llm.core.feed_forward import FeedForward
>>> ff_block = FeedForward(emb_size=256, dropout=0.1, activation=\"gelu\")
>>> decoder = CachedDecoder(num_heads=4, emb_size=256, head_size=64, feed_forward_layer=ff_block, max_seq_len=2048, dropout=0.1)
>>> x = torch.randn(2, 100, 256)
>>> y, kv_cache = decoder(x, use_cache=True, cache=None)
>>> print(y.shape) # torch.Size([2, 100, 256])
Подробнее:
----------
- GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
- HuggingFace cache mechanics: https://huggingface.co/docs/transformers/main/en/model_doc/gpt2
- Объяснения autoregressive cache: https://jalammar.github.io/illustrated-gpt2/
Пример (GPT2 style):
>>> decoder = CachedDecoder(
... feed_forward_layer=FeedForward(...),
... norm_layer=nn.LayerNorm,
... num_heads=4, emb_size=256, head_size=64, max_seq_len=128)
>>> out, cache = decoder(x, use_cache=True)
"""
def __init__(
self,
feed_forward_layer: nn.Module,
@@ -48,20 +63,22 @@ class CachedDecoder(nn.Module):
rope: RoPE = None,
):
"""
Инициализация декодера с кэшированием.
Конструктор CachedDecoder.
Поведение аналогично блоку TransformerDecoderLayer,
но с гибкой возможностью подмены любых подкомпонент (активация, norm, позиции).
Args:
feed_forward_layer: Слой feed-forward (должен быть экземпляром, а не классом)
num_heads: Количество голов внимания
emb_size: Размерность эмбеддингов
head_size: Размерность каждой головы
max_seq_len: Максимальная длина последовательности
norm_layer: Класс нормализации (по умолчанию LayerNorm)
dropout: Вероятность dropout
rope: Rotary Positional Embeddings (опционально)
Аргументы:
----------
num_heads : int
Сколько attention heads используется в каждом attention слое.
emb_size : int
Размерность входного вектора x.
head_size : int
Размерность каждой attention head; emb_size = num_heads * head_size должно быть True!
feed_forward_layer : nn.Module
Feed-forward слой (например, обычный двухслойный MLP), который применяется после нормы и внимания, и после второй нормы.
max_seq_len : int
Максимальная поддерживаемая длина последовательности (выделяет буфер для causal-маски).
dropout : float, default=0.1
Dropout после внимания и/или feedforward.
"""
super().__init__()
self._heads = MultiHeadAttention(
@@ -84,19 +101,30 @@ class CachedDecoder(nn.Module):
cache: list = None,
):
"""
Прямой проход с поддержкой кэша.
Прямой проход через Decoder Block с поддержкой KV-кэша.
В этом методе применяется:
- Causal multi-head attention (masked, не смотрит вперёд)
- Быстрая обработка длинных последовательностей за счёт сохранения и передачи KV-кэша
- LayerNorm перед каждым блоком
- Feed-forward блок и вторая LayerNorm
- Dropout
Аргументы:
----------
x : torch.Tensor
Вход [batch, seq_len, emb_size]
use_cache : bool, по умолчанию True
Включать ли накопление и возврат KV-кэша для autoregressive inferece.
cache : list, опционально
Список предыдущего KV-кеша для attention.
Возвращает:
-----------
x_ff_out : torch.Tensor
Результат после attention, модуля и их рез. связей (shape == x)
new_cache : new KV-cache (или None)
Args:
x (Tensor[float]): [batch, seq_len, emb_size] — скрытые состояния
mask (Optional[Tensor]): маска внимания (или causal mask), shape [seq_len, seq_len]
use_cache (bool): использовать кэширование KV
cache (list): кэш self-attention для быстрого авторегрессива
Returns:
output (Tensor[float]): выходные состояния [batch, seq_len, emb_size]
kv_caches (list): обновленный кэш, если use_cache
Пример:
>>> out, new_cache = decoder(x, use_cache=True, cache=old_cache)
>>> out.shape # [batch, seq_len, emb_size]
"""
norm1_out = self._norm1(x)
# Передаём все cache/use_cache дальше в attention

View File

@@ -1,79 +0,0 @@
from torch import nn
import torch
from .feed_forward import FeedForward
from .multi_head_attention import MultiHeadAttention
class Decoder(nn.Module):
"""
Базовый автогерессивный блок-декодер трансформера (без кэша KV).
Научная суть:
- Осуществляет посимвольное предсказание: каждый токен видит только предыдущие (masked attention)
- Состоит из self-attention + feedforward + residual + нормализация
- Residual connection и normalization дают стабильность и градиентный “flow” при обучении
- Механизм предложен в Vaswani et al., "Attention is All You Need", 2017
Args:
num_heads (int): количество attention-голов
emb_size (int): размер эмбеддинга
head_size (int): размер одной attention-головы
max_seq_len (int): максимальная длина последовательности
dropout (float): вероятность dropout
Пример:
>>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
>>> x = torch.randn(1, 10, 512)
>>> out = decoder(x)
>>> print(out.shape) # torch.Size([1, 10, 512])
"""
def __init__(self,
num_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
dropout: float = 0.1
):
"""
Инициализация декодера.
Параметры:
num_heads: int - количество голов внимания
emb_size: int - размерность эмбеддингов
head_size: int - размерность каждой головы внимания
max_seq_len: int - максимальная длина последовательности
dropout: float (default=0.1) - вероятность dropout
"""
super().__init__()
self._heads = MultiHeadAttention(
num_heads=num_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
dropout=dropout
)
self._ff = FeedForward(emb_size=emb_size, dropout=dropout)
self._norm1 = nn.LayerNorm(emb_size)
self._norm2 = nn.LayerNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
"""
Прямой проход через декодер.
Вход:
x: torch.Tensor - входной тензор [batch_size, seq_len, emb_size]
mask: torch.Tensor (optional) - маска внимания [seq_len, seq_len]
Возвращает:
torch.Tensor - выходной тензор [batch_size, seq_len, emb_size]
Алгоритм forward:
1. Применяем MultiHeadAttention к входу
2. Добавляем residual connection и LayerNorm
3. Применяем FeedForward сеть
4. Добавляем residual connection и LayerNorm
"""
# Self-Attention блок
attention, _ = self._heads(x, mask, use_cache=False, cache=None)
out = self._norm1(attention + x)
# FeedForward блок
ffn_out = self._ff(out)
return self._norm2(ffn_out + out)

View File

@@ -6,52 +6,71 @@ from .gelu import GELU
class FeedForward(nn.Module):
"""
Классический слой прямого распространения (FeedForward, или FFN) для архитектуры Transformer.
FeedForward — классический позиционно-независимый блок для Transformer, применяется к каждому токену отдельно.
Этот слой состоит из двух линейных преобразований с расширением внутренней размерности
в 4 раза и механизмом dropout для регуляризации. Между линейными слоями применяется
активация ReLU.
Назначение и роль:
------------------
- Реализует двухслойную (или более сложную) нейронную сеть, которая обрабатывает каждый токен ПОРЯДОЧНО независимо (по последней измерении).
- Дает модели "нелинейную мощность": любой токен может быть переосмыслен вне глобального контекста.
- После слоя внимания (MHA) FFN помогает связать смысл локальных (внутри токена) “скрытых” значений.
Научная суть:
- После внимания каждому токену применяется одинаковая двухслойная нейросеть.
- Дает глубокую нелинейность; позволяет модели не только сопоставлять, но и моделировать сложные связи между токенами.
- Изначально предложен в «Attention is All You Need» (Vaswani et al., 2017).
Архитектурные детали:
---------------------
- Обычно используется блок: (Linear → Activation → Dropout → Linear → Dropout)
- В современных LLM обычно в 4 раза расширяют скрытый слой (inner_dim = 4 * emb_size).
- Активация часто GELU или SiLU (Swish), иногда SwiGLU, ReGLU, GeGLU (см. PaLM, Llama).
Формула:
FFN(x) = Dropout(W2·act(W1·x))
где act — ReLU, GELU и др., обычно expansion x4.
Формула (обычная версия):
-------------------------
FFN(x) = Linear2(Dropout(Activation(Linear1(x))))
где Linear1: [emb_size → 4*emb_size], Activation: GELU/SiLU, Linear2: [4*emb_size → emb_size]
Алгоритм работы:
1. Входной тензор x (размерность: [batch_size, seq_len, emb_size])
2. Линейное преобразование: emb_size -> 4*emb_size
3. Активация ReLU
4. Линейное преобразование: 4*emb_size -> emb_size
5. Применение dropout
6. Возврат результата (размерность: [batch_size, seq_len, emb_size])
Параметры конструктора:
-----------------------
emb_size: int — размерность входа/выхода токена
inner_dim: int (необязательно) — размер скрытого слоя (по умолчанию 4*emb_size)
activation: str — тип активации ('gelu', 'silu', 'relu', ...), см. варианты ниже
dropout: float — dropout после каждой линейной проекции
Предназначение:
- Добавляет нелинейность в архитектуру трансформера
- Обеспечивает взаимодействие между различными размерностями эмбеддингов
- Работает независимо для каждого токена в последовательности
Пример использования:
---------------------
>>> ffn = FeedForward(emb_size=256, dropout=0.1, activation='gelu')
>>> x = torch.randn(2, 32, 256) # [batch, seq_len, emb_size]
>>> y = ffn(x)
>>> print(y.shape) # torch.Size([2, 32, 256])
Args:
emb_size (int): размерность входных эмбеддингов
dropout (float): вероятность(dropout)
activation (str): нелинейная функция (relu, gelu, gelu_exact)
Пояснения:
----------
- FeedForward не использует позицию токена — это МLP, применяемый к каждому токену независимо.
- Длина последовательности и размер батча не имеют значения (broadcast/reshape по [-2, -1]).
- Используется во всех декодерах/энкодерах трансформеров.
Подробнее смотри:
-----------------
- Vaswani et al., "Attention is All You Need": https://arxiv.org/abs/1706.03762
- GELU: https://arxiv.org/abs/1606.08415
- SwiGLU (PaLM, Llama): https://arxiv.org/abs/2002.05202
Пример:
>>> ff = FeedForward(emb_size=512, dropout=0.1)
>>> x = torch.randn(32, 10, 512)
>>> output = ff(x)
>>> print(output.shape) # torch.Size([32, 10, 512])
"""
def __init__(self, emb_size: int, dropout: float = 0.1, activation: str = "relu"):
"""
Инициализация слоя Feed Forward Network.
Инициализация FeedForward блока для трансформера.
Args:
emb_size: Размерность входных эмбеддингов
dropout: Вероятность dropout для регуляризации (по умолчанию: 0.1)
Аргументы:
----------
emb_size: int
Размерность входного и выходного эмбеддинга модели.
dropout: float, по умолчанию 0.1
Dropout после линии и/или активации (уменьшает переобучение).
activation: str, по умолчанию 'gelu'
Какая нелинейность использовать ('gelu', 'silu', 'relu' и т.д.).
inner_dim: int, опционально
Размер скрытого слоя (по умолчанию 4 * emb_size, как в оригинальном Transformer).
Внутри:
-------
- Задает структуру: Linear → Activation → Dropout → Linear → Dropout.
"""
super().__init__()
# Первый линейный слой (расширение размерности)
@@ -72,13 +91,23 @@ class FeedForward(nn.Module):
def forward(self, x: torch.Tensor):
"""
Прямой проход через слой Feed Forward Network.
Прямой проход через FeedForward блок.
Args:
x: Входной тензор размерности [batch_size, seq_len, emb_size]
Аргументы:
----------
x : torch.Tensor
Входной тензор формы [..., emb_size] (используется на каждом токене отдельно!)
Returns:
Тензор той же размерности, что и входной
Возвращает:
-----------
torch.Tensor — выход такой же формы, как вход (только последняя размерность сохраняется).
Пример:
-------
>>> ffn = FeedForward(emb_size=256)
>>> x = torch.randn(8, 16, 256)
>>> y = ffn(x)
>>> y.shape # [8, 16, 256]
"""
# Сохраняем dtype входных данных
input_dtype = x.dtype

140
llm/src/llm/core/geglu.py Normal file
View File

@@ -0,0 +1,140 @@
import torch
from torch import nn
from llm.core.gelu import GELU
class GeGLU(nn.Module):
"""
GeGLU (Gated GELU Linear Unit) — эффективная нелинейность для feed-forward блоков в современных трансформерах.
Назначение:
-----------
GeGLU — это вариант GLU (Gated Linear Unit), где «шлюз» реализован через GELU-активацию,
а затем поэлементно перемножается с другим линейным преобразованием. Такой gating-механизм позволяет повысить
выразительность MLP-блока и ускорить обучение, что подтверждено экспериментами на LLM (см. PaLM, LLaMA, T5).
Формула:
--------
GeGLU(x) = GELU(W_g x + b_g) ⊙ (W_u x + b_u) W_d + b_d
(здесь W_g, W_u, W_d — матрицы весов; GELU применяется к одной ветке, ⊙ — поэлементное умножение)
Структура блока:
----------------
1. gate = GELU(Linear_gate(x)) # ветка gating-а, shape [batch, seq, 4×emb]
2. up = Linear_up(x) # ветка передачи, shape [batch, seq, 4×emb]
3. out = gate * up # поэлементно, реализует динамическую фильтрацию информации
4. out = Linear_down(out) # проекция обратно в исходное пространство
5. out = Dropout(out) # регуляризация
Основные преимущества:
----------------------
- Позволяет эффективно обучать глубокие трансформеры (см. PaLM, LLaMA).
- Обеспечивает плавные градиенты за счёт GELU и gating-эффекта.
- Используется во многих современных LLM вместо обычных FFN или простых GLU.
Аргументы конструктора:
-----------------------
emb_size : int
Размер эмбеддинга (input и output).
dropout : float, по умолчанию 0.1
Dropout к финальному выходу (примерно 0.1-0.2 для регуляризации).
Пример использования:
---------------------
>>> geglu = GeGLU(emb_size=512, dropout=0.1)
>>> x = torch.randn(8, 16, 512)
>>> y = geglu(x)
>>> print(y.shape) # torch.Size([8, 16, 512])
Литература:
-----------
- Shazeer N., "GLU Variants Improve Transformer", 2020: https://arxiv.org/abs/2002.05202
- PaLM: https://arxiv.org/abs/2204.02311
- LLaMA: https://arxiv.org/abs/2302.13971
- T5: https://arxiv.org/abs/1910.10683
"""
def __init__(self, emb_size: int, dropout: float = 0.1):
"""
Инициализация блока GeGLU.
Создаёт три последовательных линейных слоя и задаёт GELU в качестве активации для ветки gating,
а также финальный dropout. Все размеры согласованы так, чтобы реализовать формулу GeGLU (см. описание класса).
Аргументы:
----------
emb_size : int
Размерность входного и выходного скрытого пространства (hidden size).
Данная величина определяет размерность эмбеддинга для всех внутренних вычислений.
Обычно равна размеру скрытого слоя трансформера.
dropout : float, по умолчанию 0.1
Вероятность отключения нейронов после выхода из блока (регуляризация).
Рекомендуемое значение: 0.1 (или чуть больше для небольших моделей).
Внутри:
-------
- self._gate: Linear слой размерности [emb_size, 4 * emb_size], ветка gating (проходит через GELU)
- self._up: Linear слой размерности [emb_size, 4 * emb_size], ветка передачи ("пропускная")
- self._down: Linear слой сжатия обратно к emb_size
- self._activation: Активация GELU для gating-ветки
- self._dropout: Dropout для выходного тензора
Пример:
-------
>>> block = GeGLU(emb_size=256, dropout=0.1)
>>> print(block)
"""
super().__init__()
self._gate = nn.Linear(emb_size, 4 * emb_size)
self._up = nn.Linear(emb_size, 4 * emb_size)
self._down = nn.Linear(4 * emb_size, emb_size)
self._activation = GELU()
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
"""
Прямой проход (forward) через блок GeGLU.
Для входного тензора скрытых состояний x реализует последовательность операций:
1. Gating-ветка: линейное преобразование → GELU-активация
2. Пропускная ветка: линейное преобразование
3. Поэлементное умножение результатов обеих веток (gating)
4. Проекция через Linear обратно к emb_size
5. Dropout результата для регуляризации
Математически:
--------------
gate = GELU(W_g·x + b_g)
up = W_u·x + b_u
out = gate * up
out = W_d·out + b_d
out = Dropout(out)
Аргументы:
----------
x : torch.Tensor
Входной тензор формы [batch_size, seq_len, emb_size]
(или любой совместимой формы, где последняя ось — emb_size).
Возвращает:
-----------
torch.Tensor :
Тензор той же формы [batch_size, seq_len, emb_size], прошедший через структуру GeGLU.
Пример:
-------
>>> y = geglu(x)
>>> print(y.shape) # [batch_size, seq_len, emb_size]
Примечания:
-----------
- Ветка gating строит masк для динамической фильтрации информации.
- Такой тип блока эффективно используется как замена обычного FFN в современных LLM.
"""
gate_out = self._gate(x) # [batch, seq, 4*emb]
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
up_out = self._up(x) # [batch, seq, 4*emb]
out = up_out * activation_out # поэлементное!
out = self._down(out) # [batch, seq, emb]
return self._dropout(out)

View File

@@ -1,27 +1,72 @@
import torch
from torch import nn
import math
class GELU(nn.Module):
"""
Гауссовская Эрф-активация (GELU, Gaussian Error Linear Unit).
GELU (Gaussian Error Linear Unit) — современная сглаженная функция активации для нейросетей.
Научная суть:
- Одна из самых популярных smooth активаций для трансформеров.
- Дает более гибкие аппроксимации, чем ReLU/SiLU, улучшает flow градиентов для больших LLM.
- Используется в BERT, GPT, GPT2 и почти всех современных NLP-моделях.
Формула:
GELU(x) = 0.5 * x * (1 + tanh(\sqrt{2/π} * (x + 0.044715 x³)))
Подробнее: Hendrycks & Gimpel, "Gaussian Error Linear Units (GELUs)", arXiv:1606.08415
Пример:
Мотивация и назначение:
-----------------------
- GELU используется во всех современных трансформерах (BERT, GPT, Llama) вместо ReLU, поскольку лучше передает градиенты и даёт более "мягкое" обучение.
- Формирует плавный переход между активированным и неактивированным состоянием, что улучшает устойчивость и общую производительность больших моделей.
- Дает возможность обучению «решать», насколько сильно и в каких диапазонах нужно передавать сигнал (в отличие от жёсткого ReLU).
Математическая формула:
-----------------------
GELU(x) = 0.5 * x * (1 + tanh( sqrt(2/pi) * (x + 0.044715 * x^3) ))
- Статья (Hendrycks & Gimpel, 2016): https://arxiv.org/abs/1606.08415
- В PyTorch с версии 1.4+ встроена как torch.nn.functional.gelu и torch.nn.GELU.
Как это работает:
-----------------
- Для каждого входного значения x:
- x при больших значениях (большие положительные) почти полностью передается дальше.
- x при малых (или сильно отрицательных) "заглушается" к нулю.
- На промежуточных значениях — плавный переход.
- Является аппроксимацией случайного бинома с гауссовским шумом.
Args:
-----
Нет learnable параметров — GELU работает одинаково для всех входов.
Пример использования:
---------------------
>>> gelu = GELU()
>>> y = gelu(torch.tensor([-1.0, 0.0, 1.0]))
>>> print(y)
>>> x = torch.tensor([-2.0, 0.0, 2.0])
>>> print(gelu(x)) # тензор из плавно переходящих значений
References:
-----------
- Hendrycks & Gimpel: https://arxiv.org/abs/1606.08415
- BERT, GPT-2 papers (везде используется GELU)
"""
def __init__(self):
super().__init__()
self.sqrt_2_over_pi = torch.sqrt(torch.tensor(2.0) / math.pi)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return 0.5 * x * (1 + torch.tanh(
self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))
))
"""
Прямой проход через GELU-активацию.
Args:
-----
x : torch.Tensor
Любой входной тензор.
Returns:
--------
torch.Tensor — тензор той же формы, где к каждому элементу применён GELU.
Пример:
-------
>>> gelu = GELU()
>>> x = torch.linspace(-3, 3, 7)
>>> y = gelu(x)
"""
return (
0.5
* x
* (1 + torch.tanh(self.sqrt_2_over_pi * (x + 0.044715 * torch.pow(x, 3))))
)

View File

@@ -0,0 +1,188 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.rope import RoPE
from llm.core.multi_query_attention import MultiQueryAttention
from llm.core.rms_norm import RMSNorm
from llm.core.geglu import GeGLU
class GemmaDecoder(nn.Module):
"""
GemmaDecoder — декодерный блок архитектуры Gemma (Google DeepMind, 2024).
Назначение:
-----------
Данный блок реализует одну «ячейку» декодерного стека в модели Gemma. Архитектура схожа с современными LLM (Llama/Mistral),
но имеет уникальные особенности attention и feed-forward слоёв, соответствующие спецификации Gemma.
Архитектурные компоненты:
-------------------------
- LayerNorm или RMSNorm
- Multi-head self-attention (обычно Multi-Query Attention)
- Skip connection (остаточное сложение)
- Feed-forward блок (может включать SwiGLU, GeGLU или классический FFN)
- Повторная нормализация
- Dropout (регуляризация на уровне attention и feed-forward)
Алгоритм прямого прохода:
-------------------------
1. norm1_out = LayerNorm(x)
2. attention_out = Attention(norm1_out, ...)
3. resid1 = attention_out + x
4. norm2_out = LayerNorm(resid1)
5. ffn_out = FeedForward(norm2_out)
6. output = ffn_out + resid1
Теоретические детали:
---------------------
- В Gemma используются техники оптимизации памяти и ускорения инференса (например, shared K/V-головы, Rope, кастомные FFN).
- Поддержка кэширования attention для ускорения генерации (KV cache).
- Блок проектирован для использования в стеке, повторяется N раз во всей LLM.
Аргументы конструктора:
----------------------
num_q_heads : int
Число голов query (Query Heads) для attention.
num_kv_heads : int
Число ключевых/значенческих голов (Key/Value Heads).
emb_size : int
Размерность скрытого пространства (embedding dim).
head_size : int
Размерность одной attention-головы.
max_seq_len : int
Максимальная длина последовательности (ограничение на causal mask).
dropout : float, optional
Dropout для регуляризации (примерно 0.00.1).
rope : RoPE, optional
Позиционное кодирование Rotary Position Embedding.
Пример использования:
---------------------
>>> decoder = GemmaDecoder(
... num_q_heads=8,
... num_kv_heads=2,
... emb_size=256,
... head_size=32,
... max_seq_len=1024,
... dropout=0.1,
... rope=rope_obj
... )
>>> x = torch.randn(2, 24, 256)
>>> out, cache = decoder(x, mask=None, use_cache=True, cache=None)
>>> print(out.shape) # torch.Size([2, 24, 256])
Литература и ссылки:
--------------------
- Gemma (официальный релиз): https://ai.google.dev/gemma
- Gemma paper: https://arxiv.org/abs/2403.07794
- Rotary Embedding: https://arxiv.org/abs/2104.09864
- Multi-Query Attention: https://arxiv.org/abs/1911.02150
- Llama: https://arxiv.org/abs/2302.13971
"""
def __init__(self,
num_q_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
rope: RoPE,
dropout: float = 0.1
):
"""
Конструктор слоя GemmaDecoder.
Производит инициализацию всех подслоёв (нормализация, multi-head или multi-query attention, feed-forward блок, Dropout)
согласно архитектуре декодера Gemma. Обеспечивает поддержку rotary-позиционирования, обучения и inference с caching.
Аргументы:
----------
num_q_heads : int
Количество query-голов в attention (определяет степень параллелизма внимания).
emb_size : int
Размер пространства эмбеддинга (embedding dim, input/output размерность слоя).
head_size : int
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
max_seq_len : int
Максимальная длина последовательности, для которой поддерживается attention и маскирование.
rope : RoPE
Объект для rotary positional encoding (позиционное кодирование для attention).
dropout : float, default=0.1
Dropout после attention и feed-forward для регуляризации (обычно 0.00.1).
Внутри:
-------
- Инициализируются все слои norm, attention, rope, FFN, остаточные соединения.
- Строится causal-маска автоагрессивного attention (если требуется).
- Гибко поддерживает работу как на training, так и для быстрых inference/генерации.
Пример:
-------
>>> decoder = GemmaDecoder(
... num_q_heads=8, emb_size=512, head_size=64, max_seq_len=1024, rope=rope_obj, dropout=0.05
... )
"""
super().__init__()
self._heads = MultiQueryAttention(
num_q_heads=num_q_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
rope=rope,
dropout=dropout
)
self._ff = GeGLU(emb_size=emb_size, dropout=dropout)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
"""
Прямой проход (forward) через GemmaDecoder.
Последовательно реализует:
- Нормализацию входа (обычно RMSNorm или LayerNorm)
- Self-attention (multi-query или multi-head, с опциональной маской и кэшем)
- Остаточное сложение (skip connection)
- Вторую нормализацию
- Feed-Forward-блок (например, GeGLU/SwiGLU)
- Ещё одно residual сложение
Поддерживает autoregressive режим с caching (KV-слоты attention для ускорения генерации).
Аргументы:
----------
x : torch.Tensor
Входной скрытый тензор формы [batch_size, seq_length, emb_size].
mask : torch.Tensor, optional
Attention mask (например, causal или padding mask). Если None, используется встроенная causal mask.
use_cache : bool, по умолчанию True
Если True — возвращается кэш KV для ускорения autoregressive генерации.
cache : list, optional
Кэш предыдущих ключей/значений attention (если используется при инференсе).
Возвращает:
-----------
Tuple[torch.Tensor, cache]:
- Выход декодера с той же формой [batch_size, seq_length, emb_size]
- Кэш attention (если use_cache=True), иначе None
Пример:
-------
>>> out, new_cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
>>> out.shape # [batch_size, seq_len, emb_size]
Примечания:
-----------
- mask используется для ограничения внимания (напр., каузальный режим GPT/LLM).
- Для ускорения в режиме генерации рекомендуется использовать use_cache=True + передавать cache.
"""
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)

View File

@@ -0,0 +1,138 @@
from torch import nn
import torch
from .feed_forward import FeedForward
from .multi_head_attention import MultiHeadAttention
class GptDecoder(nn.Module):
"""
Decoder — базовый transformer decoder block (pre-LN), классический строительный блок современных языковых моделей.
Назначение:
-----------
- Инкапсулирует архитектуру: norm → multi-head self-attention → residual → norm → feed-forward → residual
- Подходит как для LLM/GPT, так и для любых autoregressive sequence моделей.
- Использует masked self-attention: каждый токен видит только предыдущие (никакого \"заглядывания в будущее\").
- Стабильность обеспечивается через residual connections и LayerNorm после каждого sub-layer.
Почему это важно?
-----------------
- Все современные языковые модели состоят из подобных блоков, соединённых в стек.
- Алгоритм residual+norm позволяет проще обучать очень глубокие сети.
- Разделение на attention+FFN дает и локальные, и глобальные взаимодействия между токенами.
Формула работы (псевдокод):
---------------------------
y1 = norm1(x)
attn_out = Attention(y1)
x2 = x + attn_out # residual
y2 = norm2(x2)
ffn_out = FFN(y2)
out = x2 + ffn_out # residual
Архитектурные особенности:
--------------------------
- Поддержка внимания с маской (causal mask или произвольная attention mask)
- Residual connections для каждого блока (attention, FFN)
- Pre-LN (norm перед каждым подблоком)
- Зависит от переданных блоков self_attention и feed_forward, а не их реализации
References:
-----------
- Vaswani et al., \"Attention is All You Need\" (2017): https://arxiv.org/abs/1706.03762
- Illustrated Transformer: https://jalammar.github.io/illustrated-transformer/
- Transformer Circuits (дружественное описание): https://transformer-circuits.pub/2021/framework/index.html
Пример:
-------
>>> decoder = Decoder(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
>>> x = torch.randn(1, 10, 512)
>>> out = decoder(x)
>>> print(out.shape) # torch.Size([1, 10, 512])
"""
def __init__(
self,
num_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
dropout: float = 0.1,
):
"""
Инициализация стандартного decoder-блока для Transformer.
Аргументы:
----------
num_heads: int
Количество attention голов (как делить emb_size на heads)
emb_size: int
Размерность эмбеддингов (и входа и выхода)
head_size: int
Размерность одной attention-головы (emb_size = num_heads * head_size)
max_seq_len: int
Максимальная длина последовательности (важно для mask)
dropout: float, default=0.1
Dropout после внимания и FFN
Внутри:
-------
- Создаёт слой MultiHeadAttention (masked/casual)
- Создаёт двухслойный FeedForward (SwiGLU или GELU)
- Применяет 2 слоя LayerNorm для стабилизации градиентов
- Все блоки реализованы как PyTorch-модули
"""
super().__init__()
self._heads = MultiHeadAttention(
num_heads=num_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
dropout=dropout,
)
self._ff = FeedForward(emb_size=emb_size, dropout=dropout)
self._norm1 = nn.LayerNorm(emb_size)
self._norm2 = nn.LayerNorm(emb_size)
def forward(
self,
x: torch.Tensor,
use_cache: bool = False,
cache: list = None,
attention_mask=None
) -> tuple:
"""
Один прямой проход через Transformer decoder block.
Аргументы:
----------
x : torch.Tensor
Входной тензор [batch_size, seq_len, emb_size]
mask : torch.Tensor, optional
Attention/causal mask (по умолчанию None, тогда будет casual mask по длине seq_len)
Возвращает:
-----------
out : torch.Tensor
Выходной тензор той же формы, что и x
Алгоритм:
---------
- Применяем attention к нормализованному входу (layernorm)
- Добавляем residual-связь (attention + исходный вход)
- Применяем FFN к нормализованному результату (layernorm)
- Добавляем residual-связь (ffn + предыдущий выход)
"""
# Self-Attention блок
attention, kv_caches = self._heads(x, attention_mask, use_cache=use_cache, cache=cache)
out = self._norm1(attention + x)
# FeedForward блок
ffn_out = self._ff(out)
result = self._norm2(ffn_out + out)
if use_cache:
return (result, kv_caches)
else:
return (result, None)

View File

@@ -0,0 +1,413 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.rope import RoPE
class GroupedQueryAttention(nn.Module):
"""
Grouped Query Attention (GQA)
=============================
Что такое Grouped Query Attention?
----------------------------------
Это разновидность многоголового внимания (multi-head), где для Q (query) голов может быть больше, чем для K/V (key/value) голов:
вместо стандартного MHA (num_q_heads == num_kv_heads) — меньшее число K/V разделяет информацию для всех Q.
Такой подход экономит память и ускоряет инференс, сохраняя высокое качество внимания (используется например в Mistral, Llama-2, GPT-4 и др.).
Зачем это нужно?
----------------
- Сокращает количество вычислений и размер KV-кэша в больших LLM.
- Позволяет эффективно масштабировать число attention-глав для моделирования сложных связей, не увеличивая размер всех матриц.
Как работает?
-------------
1. Q формируется для каждого query-head (их много)
2. K и V вычисляется только для меньшего числа KV-heads (обычно в 2-4 раза меньше, чем Q)
3. К/V heads дублируются (repeat) так, чтобы на каждую Q-head был свой набор
4. Всё внимание (Q,K,V) — стандартное scaled dot-product, только более эффективно и с компрессией
Поддержка дополнительных фич:
-----------------------------
- Rotary Position Encoding (RoPE) для Q и K (для относительной позиции)
- Sliding-window attention mask (можно ограничить исторический контекст, как в Mistral)
- Кэширование Q/K/V (ускоряет генерацию автоагретивно)
Аргументы конструктора:
-----------------------
num_q_heads: int — количество query голов (Q)
num_kv_heads: int — количество key/value голов (обычно меньше Q)
emb_size: int — embedding размерность
head_size: int — размер каждой attention-head
max_seq_len: int — максимальная длина последовательности
window_size: int — размер sliding window (макс. количество токенов в контексте внимания)
rope: RoPE (по желанию) — если задан, то будет применяться RoPE для Q и K
dropout: float — dropout после линейной проекции
Пример использования:
---------------------
>>> gqa = GroupedQueryAttention(num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32, max_seq_len=1024, window_size=256)
>>> x = torch.randn(2, 128, 256)
>>> y, cache = gqa(x)
>>> print(y.shape) # torch.Size([2, 128, 256])
Где прочитать подробнее:
------------------------
- LlamaV2 (Section 2.3): https://arxiv.org/abs/2307.09288
- Mistral: https://arxiv.org/abs/2310.06825
- \"Self-attention with linear complexity\" (Vila et al.): https://arxiv.org/abs/2302.05442
- Обзор: https://huggingface.co/blog/mistral
"""
def __init__(
self,
num_q_heads: int,
num_kv_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
window_size: int,
rope: RoPE = None,
dropout: float = 0.1,
):
"""
Инициализация слоя Grouped Query Attention (GQA).
Этот конструктор задаёт архитектуру эффективного внимания, где Q-голов может быть больше, чем KV-голов.
Это экономит память/вычисления и позволяет реализовать сдвигающееся "окно" внимания (Mistral-style).
Аргументы:
----------
num_q_heads : int
Количество Query attention heads (чаще всего кратно num_kv_heads, напр. 8/2, 12/4).
Чем больше — тем богаче контекстное окно каждой позиции.
num_kv_heads : int
Количество Key/Value attention heads (обычно 2-4, иногда меньше, чем Query).
В современных LLM принято уменьшать их число для оптимизации скорости/кэша.
emb_size : int
Размерность входного embedding (общий размер вектора на токен).
head_size : int
Размерность одной головы внимания.
Требуется: num_q_heads * head_size == emb_size (иначе ошибка).
max_seq_len : int
Максимальная поддерживаемая длина входной последовательности; определяет размер триангулярной (causal/sliding window) маски.
window_size : int
Размер "скользящего окна" истории — сколько токенов учитывается при слепом внимании (как у Mistral).
Чем меньше значение, тем локальнее работает внимание (и меньше память/время).
rope : RoPE, опционально
Если задан — применяется Rotary Positional Encoding к Q и K для относительного позиционного кодирования.
dropout : float, по умолчанию 0.1
Dropout после линейной проекции attention (обычно 0.1, помогает борьбе с переобучением).
Что создаётся внутри:
---------------------
- Линейные слои для получения Q, K, V из embedding.
- Буфер для causal/sliding window mask (матрица масок в зависимости от window_size и max_seq_len).
- Линейный слой для финального преобразования (объединение всех голов и возврат к emb_size).
- Dropout перед возвратом.
Пример:
-------
>>> attn = GroupedQueryAttention(
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
... max_seq_len=1024, window_size=256, dropout=0.1)
"""
super().__init__()
self._num_heads = num_q_heads
self._num_kv_heads = num_kv_heads
self._head_size = head_size
self._max_seq_len = max_seq_len
self._rope = rope
self._window_size = window_size
self._q = nn.Linear(emb_size, self._num_heads * head_size)
self._k = nn.Linear(emb_size, num_kv_heads * head_size)
self._v = nn.Linear(emb_size, num_kv_heads * head_size)
# Создание causal маски
mask = self._create_sliding_window_mask(max_seq_len, self._window_size)
self.register_buffer(
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
)
self._layer = nn.Linear(head_size * self._num_heads, emb_size)
self._dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
use_cache: bool = True,
cache: list = None,
):
"""
Шаг внимания в режиме Grouped Query Attention —
реализует эффективное многооконное внимание с раздельными Q/KV и sliding/casual mask.
Что происходит в этом методе:
-----------------------------
- Преобразует входной тензор x (токеновые эмбеддинги) в Q, K, V-матрицы с учётом разного числа голов для Q и KV.
- Формирует attention "маску" для sliding window, если нужно ограничить историю.
- Применяет RoPE (если задан) к Q и K, вносит позиционную информацию.
- При работе с кэшем дополняет ключи и значения предыдущими (ускоряет генерацию).
- Повторяет K/V головы для соответствия количеству Q (чтобы на каждую Q-head приходился свой KV).
- Считает обычное scaled dot-product внимание, применяет маску (не даёт видеть будущее, как и в autoregressive).
- Softmax, смешивание V на основе attention, объединение всех голов.
- Dropout и финальное линейное преобразование обратно к emb_size.
Аргументы:
----------
x : torch.Tensor
Входной тензор размера [batch, seq_len, emb_size]
mask : torch.Tensor, по умолчанию None
Матричная маска для внимания (можно передать внешнюю или использовать встроенную sliding window mask)
use_cache : bool, по умолчанию True
Нужно ли использовать/возвращать кэш KV для быстрых автогенераций.
cache : list, опционально
Ранее сохранённый кэш KV (используется для инференса по одному токену)
Возвращает:
-----------
- output: torch.Tensor формы [batch, seq_len, emb_size]
- kv_cache: кэш новых KV (если use_cache=True), иначе None
Важно:
-------
- Реализует Mistral-style attention: к каждой Q-head в итоге “приписан” собственный (но потенциально дублированный) KV-head.
- Sliding window ограничивает область вижимости в attention (ускоряет генерацию на длинных последовательностях).
- Использование RoPE опционально — но необходимо для современных архитектур LLM.
Пример:
-------
>>> attn = GroupedQueryAttention(num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32, max_seq_len=1024, window_size=256)
>>> x = torch.randn(2, 128, 256)
>>> y, kv_cache = attn(x)
>>> print(y.shape) # torch.Size([2, 128, 256])
"""
batch_size, seq_len, emb_size = x.shape
if seq_len > self._max_seq_len:
raise ValueError(
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
)
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
# Шаг 2: Изменение формы для multi-head
# [batch_size, seq_len, num_heads * head_size]
# -> [batch_size, seq_len, num_heads, head_size]
# Измените форму запроса (query) на batch_size × num_q_heads × seq_len × head_size.
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
# Измените форму ключа (key) и значения (value) на batch_size × num_kv_heads × seq_len × head_size.
k = k.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)
v = v.reshape(batch_size, seq_len, self._num_kv_heads, self._head_size)
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
start_pos = 0
if cache is not None:
k_cache, v_cache = cache
cache_len = k_cache.shape[2]
start_pos = cache_len
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
if self._rope is not None:
# Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
# 5. Кэширование (для autoregressive generation)
if cache is not None:
k_cache, v_cache = cache
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
v = torch.cat([v_cache, v], dim=2)
# Если use_cache == True, то сохраните матрицы ключа и значения для кэша (это нужно сделать до дублирования голов).
#if use_cache == True:
# # Обрезаем до последних window_size токенов
# k_to_cache = k[:, :, -self._window_size:, :]
# v_to_cache = v[:, :, -self._window_size:, :]
# kv_cache = (k_to_cache, v_to_cache)
# Продублируйте головы в тензорах ключа (key) и значения (value), чтобы получился тензор размера на batch_size × num_q_heads × seq_len × head_size.
#k = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)
#v = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)
k_expanded = self._repeat_kv_heads(k, self._num_heads, self._num_kv_heads)
v_expanded = self._repeat_kv_heads(v, self._num_heads, self._num_kv_heads)
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
# И разделить все значения в матрице внимания на корень из head_size.
scores = q @ k_expanded.transpose(-2, -1) / (self._head_size ** 0.5)
# 8. Применение маски
k_seq_len = k_expanded.size(2) # Длина K после concat с кэшем
if cache is None:
# Случай 1: Без кэша - полная квадратная маска
# scores: [B, H, seq_len, seq_len]
# Применяем маску [:seq_len, :seq_len]
scores = scores.masked_fill(
~self._tril_mask[:seq_len, :seq_len],
float("-inf")
)
# Применить к матрице внимания (построчно) функцию Softmax.
weights = F.softmax(scores, dim=-1)
# Перемножим матрицу внимания и матрицу значения.
x_out = weights @ v_expanded # [B, T, hs]
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
# Transpose обратно и concatenate heads
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
x_out = x_out.contiguous() # Важно для reshape!
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
# Пропустите получившийся тензор через последний линейный слой.
# 3. Проецируем в пространство эмбеддингов
projected_output = self._layer(concatenated_attention)
# 4. Применяем dropout для регуляризации
output = self._dropout(projected_output)
if use_cache:
# Обрезаем оригинальный K и V (до дублирования)
k_to_cache = k[:, :, -self._window_size:, :]
v_to_cache = v[:, :, -self._window_size:, :]
kv_cache = (k_to_cache, v_to_cache)
return output, kv_cache
else:
return output, None
def _repeat_kv_heads(
self,
kv: torch.Tensor,
num_q_heads: int,
num_kv_heads: int
) -> torch.Tensor:
"""
Приводит число голов K/V к числу голов Q путём поэлементного повторения (tile) KV-голов.
Зачем это нужно?
----------------
В Grouped Query Attention (Mistral, Llama-2, GPT-4 и др.) обычно num_kv_heads < num_q_heads.
Чтобы каждая Query-head могла смотреть на свою собственную (пусть и общую) KV, мы "нарезаем" или повторяем KV столько раз, сколько требуется — это экономит память и ускоряет генерацию.
Алгоритм:
---------
- kv имеет форму [batch_size, num_kv_heads, seq_len, head_size]
- Для каждого KV-head делается n_repeat = num_q_heads // num_kv_heads по head-axis (обычно целое)
- На выходе форма [batch_size, num_q_heads, seq_len, head_size], где каждый KV-head дублирован для нужного количества Q-heads.
Args:
-----
kv : torch.Tensor
Входной тензор KV (обычно после linear layer on эмбеддинги), размер [batch_size, num_kv_heads, seq_len, head_size]
num_q_heads : int
Сколько должно быть Q-голов (их больше!)
num_kv_heads : int
Сколько KV-голов было (их меньше!)
Returns:
--------
torch.Tensor формы [batch_size, num_q_heads, seq_len, head_size], где KV-головы повторены как требуется.
Пример:
-------
num_q_heads = 8, num_kv_heads = 2
[KV0, KV1] -> [KV0, KV0, KV0, KV0, KV1, KV1, KV1, KV1]
# Каждый KV-head дублируется 4 раза, чтобы покрыть все 8 Q-heads.
"""
batch_size, num_kv_heads, seq_len, head_size = kv.shape
if num_q_heads == num_kv_heads:
# Нет необходимости дублировать
return kv
# Вычисляем сколько раз нужно повторить каждую голову
num_repeats = num_q_heads // num_kv_heads
# repeat_interleave дублирует каждую голову num_repeats раз
# [B, num_kv_heads, S, hs] -> [B, num_q_heads, S, hs]
# [B, num_kv_heads, S, hs] -> [B, num_kv_heads, 1, S, hs]
kv = kv.unsqueeze(2)
# [B, num_kv_heads, 1, S, hs] -> [B, num_kv_heads, num_repeats, S, hs]
kv = kv.repeat(1, 1, num_repeats, 1, 1)
# [B, num_kv_heads, num_repeats, S, hs] -> [B, num_q_heads, S, hs]
kv = kv.reshape(batch_size, num_q_heads, seq_len, head_size)
return kv
def _create_sliding_window_mask(
self,
max_seq_len: int,
window_size: int,
device: torch.device = None
) -> torch.Tensor:
"""
Создаёт маску для Sliding Window Attention (ограниченного окна внимания).
Зачем нужна эта маска?
----------------------
В современных LLM (например, Mistral) self-attention работает не по всей истории, а только в узком "скользящем окне":
каждый токен видит только предшествующие (или соседние) токены на расстоянии window_size.
Это ускоряет инференс на длинных текстах и экономит память, но сохраняет ключевые зависимости в пределах окна.
Как работает алгоритм:
----------------------
- Для каждого токена mask[i, j] == True только если токен j находится СЛЕВА и не дальше, чем window_size позиций (или сам i).
- Главное: mask всегда "нижнетреугольная" (causal), плюс полоса шириной window_size вдоль главной диагонали.
- Всё за пределами окна — False (attention нельзя).
Args:
-----
max_seq_len : int
Максимальная длина последовательности (размер будущей attention-матрицы).
window_size : int
Сколько предыдущих токенов доступно для внимания у каждого шага (вкл. сам себя).
device : torch.device, опционально
На каком устройстве (cpu/gpu) создавать маску.
Returns:
--------
torch.Tensor
Маска внимания формы [max_seq_len, max_seq_len], где True — допускается внимание (иначе False).
Пример:
-------
>>> mask = create_sliding_window_mask(8, 3)
>>> print(mask.int())
tensor([[1, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 0, 0, 0, 0, 0],
[0, 1, 1, 1, 0, 0, 0, 0],
[0, 0, 1, 1, 1, 0, 0, 0],
[0, 0, 0, 1, 1, 1, 0, 0],
[0, 0, 0, 0, 1, 1, 1, 0],
[0, 0, 0, 0, 0, 1, 1, 1]])
"""
row_indices = torch.arange(max_seq_len, device=device).unsqueeze(1) # [max_seq_len, 1]
col_indices = torch.arange(max_seq_len, device=device).unsqueeze(0) # [1, max_seq_len]
causal_mask = col_indices <= row_indices
window_mask = (row_indices - col_indices) <= window_size
mask = causal_mask & window_mask
return mask

View File

@@ -1,103 +0,0 @@
import torch
from torch import nn
import torch.nn.functional as F
from math import sqrt
from .rope import RoPE
class HeadAttention(nn.Module):
"""
Одноголовый механизм внимания (scaled dot-product attention) — фундаментальный строительный блок всех современных Transformer.
Научная суть:
- Attention учит модель самостоятельно "выбирать" важные связи между словами, независимо от их положения.
- Механизм causal mask гарантирует невозможность "заглядывания в будущее" при генерации (авторегрессия).
Формула:
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) · V
(Q — запросы, K — ключи, V — значения; d_k — размерность ключа)
Поддерживает Rotary Position Encoding (RoPE) для относительного позиционного кодирования.
Args:
emb_size (int): размер входного эмбеддинга
head_size (int): размерность attention-головы
max_seq_len (int): максимальная длина последовательности
rope (RoPE, optional): экземпляр RoPE для позиций
Примечания:
- Использует нижнетреугольную маску для предотвращения "заглядывания в будущее"
- Автоматически адаптируется к разным версиям PyTorch
- Поддерживает batch-обработку входных данных
Пример использования:
>>> attention = HeadAttention(emb_size=64, head_size=32, max_seq_len=128)
>>> x = torch.randn(1, 10, 64)
>>> output, _ = attention(x)
>>> print(output.shape) # torch.Size([1, 10, 32])
"""
def __init__(self, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None):
super().__init__()
self._emb_size = emb_size
self._head_size = head_size
self._max_seq_len = max_seq_len
self._rope = rope
# Линейные преобразования для Q, K, V
self._k = nn.Linear(emb_size, head_size)
self._q = nn.Linear(emb_size, head_size)
self._v = nn.Linear(emb_size, head_size)
# Создание causal маски
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer('_tril_mask', mask.bool() if hasattr(torch, 'bool') else mask.byte())
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: tuple = None) -> tuple:
"""
Прямой проход через слой внимания.
Аргументы:
x (torch.Tensor): Входной тензор формы [batch_size, seq_len, emb_size]
Возвращает:
torch.Tensor: Выходной тензор формы [batch_size, seq_len, head_size]
Исключения:
ValueError: Если длина последовательности превышает max_seq_len
Пример внутренних преобразований:
Для входа x.shape = [2, 5, 64]:
1. Q/K/V преобразования -> [2, 5, 32]
2. Scores = Q·K^T -> [2, 5, 5]
3. После маски и softmax -> [2, 5, 5]
4. Умножение на V -> [2, 5, 32]
"""
seq_len = x.shape[1]
if seq_len > self._max_seq_len:
raise ValueError(f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}")
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
if self._rope is not None:
# ✅ Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q) # [B, T, hs]
k = self._rope(k) # [B, T, hs]
if cache is not None:
k_cache, v_cache = cache
k = torch.cat([k_cache, k], dim=1) # [B, cache_len + T, hs]
v = torch.cat([v_cache, v], dim=1) # [B, cache_len + T, hs]
scores = q @ k.transpose(-2, -1) / sqrt(self._head_size)
if cache is None:
scores = scores.masked_fill(~self._tril_mask[:seq_len, :seq_len], float('-inf'))
weights = F.softmax(scores, dim=-1)
x_out = weights @ v # [B, T, hs]
if use_cache is True:
return (x_out, (k, v))
else:
return (x_out, None)

View File

@@ -0,0 +1,134 @@
import torch
from torch import nn
from llm.core.rms_norm import RMSNorm
from llm.core.swi_glu import SwiGLU
from llm.core.rope import RoPE
from llm.core.group_query_attention import GroupedQueryAttention
class MistralDecoder(nn.Module):
"""
MistralDecoder — стек декодирующих блоков, реализующий архитектуру Mistral-style Transformer.
Назначение:
-----------
Этот класс описывает один или несколько блоков декодера, включающих Grouped Query Attention (GQA),
sliding window attention и SwiGLU feed-forward, как реализовано в моделях Mistral и Llama 2.
Ключевые особенности архитектуры:
---------------------------------
- Использует GQA: для каждого токена вычисляется attention c раздельным числом Q и KV голов (сильно ускоряет LLM).
- Sliding Window Attention: внимание ограничено окном из window_size элементов (ускоряет обработку длинных текстов).
- Rotary Positional Embedding (RoPE): позиционная информация интегрируется вращением Q/K.
- RMSNorm перед и после внимания и FFN (устойчивое обучение).
- SwiGLU в качестве нелинейности вместо стандартного GELU (больше capacity в модели).
Аргументы конструктора:
-----------------------
num_layers : int — сколько блоков-декодеров в стеке
параметры GQA: num_q_heads, num_kv_heads, emb_size, head_size, max_seq_len, window_size, rope, dropout
- все они идут в каждый слой (блок) декодера
Пример использования:
---------------------
>>> decoder = MistralDecoder(
... num_q_heads=8, num_kv_heads=2, emb_size=256, head_size=32,
... max_seq_len=4096, window_size=256, rope=rope, dropout=0.1)
>>> x = torch.randn(2, 512, 256)
>>> out, cache = decoder(x)
>>> print(out.shape) # torch.Size([2, 512, 256])
Подробнее:
----------
- Mistral: https://arxiv.org/abs/2310.06825
- Llama 2: https://arxiv.org/abs/2307.09288
- Open LLM обзор: https://huggingface.co/blog/mistral
"""
def __init__(self,
num_q_heads: int,
num_kv_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
window_size: int,
rope: RoPE,
dropout: float = 0.1
):
"""
Инициализация стека декодеров MistralDecoder.
Аргументы:
----------
num_layers : int
Сколько слоёв (декодеров/GQA-блоков) собрать в стек.
num_q_heads : int
Количество Query-heads в attention (их больше, экономит память).
num_kv_heads : int
Количество Key/Value-heads в attention (их меньше для быстрой генерации).
emb_size : int
Размерность embedding (должна делиться на num_q_heads без остатка).
head_size : int
Размер одного attention head.
max_seq_len : int
Максимально обрабатываемая длина последовательности.
window_size : int
Размер окна для sliding window attention.
rope : RoPE
Rotary Positional Embedding для Q/K.
dropout : float, опционально
Dropout на каждом attention/FFN (по умолчанию 0.1).
Внутри:
-------
- Собираются num_layers Sequential-блоков из GQA + SwiGLU + RMSNorm.
- Все параметры передаются в каждый слой (блок).
"""
super().__init__()
self._heads = GroupedQueryAttention(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
window_size=window_size,
rope=rope,
dropout=dropout
)
self._ff = SwiGLU(emb_size=emb_size, dropout=dropout)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
"""
Прямой проход через стек MistralDecoder.
Аргументы:
----------
x : torch.Tensor
Входные эмбеддинги (обычно shape [batch, seq_len, emb_size]).
use_cache : bool, по умолчанию True
Включить ли кэширование для ускорения генерации (авторегрессия).
cache : list, опционально
Предыдущий кеш attention-блоков (или None).
Возвращает:
-----------
out : torch.Tensor
Тензор после декодирования (shape соответствует x).
new_cache : list (или None)
Новый кэш attention для дальнейшей генерации (или None, если use_cache=False).
"""
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)

View File

@@ -0,0 +1,211 @@
from torch import nn
import torch
import torch.nn.functional as F
from llm.core.rope import RoPE
from llm.core.group_query_attention import GroupedQueryAttention
from llm.core.moe import MoE
from llm.core.rms_norm import RMSNorm
class MixtralDecoder(nn.Module):
"""
MixtralDecoder — декодерный блок для Mixtral/MoE-трансформеров (см. Mixtral 8x7B, Mistral v0.2 и др.).
Назначение:
-----------
MixtralDecoder реализует один модульный слой глубокой трансформерной архитектуры с Mixture-of-Experts (MoE) Feed-Forward Network и Grouped Query Attention (GQA).
Поддерживает разреженную активацию и масштабируемое количество экспертов, оптимально для больших LLM.
Архитектура блока:
------------------
- RMSNorm -> Grouped Query Attention (GQA)
- skip-connection
- RMSNorm -> MoE (SwiGLU-эксперты)
- skip-connection
Для входа `x` проходит:
1. norm1_out = RMSNorm(x)
2. attention, kv_caches = GQA(norm1_out, ...)
3. out = attention + x # residual connection
4. norm2_out = RMSNorm(out)
5. ffn_out = MoE(norm2_out)
6. return (ffn_out + out, kv_caches)
Теоретическая мотивация:
------------------------
- Использование MoE (см. https://arxiv.org/abs/1701.06538) позволяет кратно увеличивать capacity без роста затрат на ff-часть.
- Grouped Query Attention эффективно масштабирует self-attention для больших моделей (см. Mistral, Llama 2/3).
- RMSNorm (Root Mean Square LayerNorm) стабилизирует градиенты и память.
- Является строительным блоком для стека декодеров в Mixtral-моделях (см. Mixtral, Mistral, LLaMA).
Аргументы конструктора:
----------------------
num_q_heads : int
Число query-голов в attention.
num_kv_heads : int
Число key-value голов (группировка ключей/values).
emb_size : int
Скрытый размер эмбеддинга.
head_size : int
Размерность одной головы (emb_size // num_q_heads).
max_seq_len : int
Максимальная поддерживаемая длина последовательности.
num_experts : int
Количество «экспертов» (MoE).
top_k_experts : int
Сколько одновременно экспертов активируется для одного токена.
window_size : int
Размер окна внимания (используется для efficient attention).
rope : RoPE
Реализация позиционного кодирования RoPE.
dropout : float
Вероятность Dropout для регуляризации.
Пример использования:
---------------------
>>> decoder = MixtralDecoder(... параметры ...)
>>> x = torch.randn(batch, seq, emb_size)
>>> out, cache = decoder(x, mask=None, use_cache=True)
>>> out.shape
Литература и ссылки:
--------------------
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
- Mistral paper: https://arxiv.org/abs/2310.06825
- GQA: https://arxiv.org/abs/2305.14236
- RMSNorm: https://arxiv.org/abs/1910.07467
"""
def __init__(self,
num_q_heads: int,
num_kv_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
num_experts: int,
top_k_experts: int,
window_size: int,
rope: RoPE,
dropout: float = 0.1
):
"""
Конструктор декодерного блока MixtralDecoder.
Осуществляет инициализацию всех под-компонентов слоя: Attention (Grouped Query Attention), MoE (Mixture-of-Experts, SwiGLU)
и нормализации (RMSNorm). Позволяет гибко настраивать архитектуру под специфику задач и размеры LLM.
Аргументы:
----------
num_q_heads : int
Количество голов внимания (queries) в механизме GroupedQueryAttention.
Чем больше — тем тоньше дискретизация внимания по подпространствам признаков.
num_kv_heads : int
Количество групп ключей/значений (key-value heads) для GQA.
Позволяет балансировать производительность и память.
emb_size : int
Размерность эмбеддингового пространства внутри слоя (hidden).
head_size : int
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
max_seq_len : int
Максимально поддерживаемая длина токенизированной последовательности.
num_experts : int
Количество экспертов в слое MoE (размер пула SwiGLU-экспертов).
top_k_experts : int
Сколько экспертов по роутингу активируется на 1 токен (разреженность — эффективная экономия вычислений).
window_size : int
Размер окна для attention (может использоваться для ограничения receptive field, как в Mistral).
rope : RoPE
Объект позиционного кодирования RoPE (Rotary Positional Embedding), необходим для архитектуры внимания.
dropout : float, по умолчанию 0.1
Вероятность зануляции выходных значений для регуляризации и борьбы с переобучением.
Пример:
-------
>>> decoder = MixtralDecoder(
... num_q_heads=8,
... num_kv_heads=2,
... emb_size=256,
... head_size=32,
... max_seq_len=1024,
... num_experts=4,
... top_k_experts=2,
... window_size=128,
... rope=rope_module,
... dropout=0.05
... )
"""
super().__init__()
self._heads = GroupedQueryAttention(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
window_size=window_size,
rope=rope,
dropout=dropout
)
self._ff = MoE(
emb_size=emb_size,
num_experts=num_experts,
top_k_experts=top_k_experts,
dropout=dropout
)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
"""
Прямой проход (forward) через декодерный блок MixtralDecoder.
Данный метод реализует последовательную обработку входных скрытых состояний (x) через:
- нормализацию (RMSNorm),
- attention-модуль (Grouped Query Attention) с опциональным применением маски и кэша ключей/значений для ускорения инференса,
- остаточное сложение (residual connection),
- повторную нормализацию,
- feed-forward блок на основе Mixture-of-Experts (MoE),
- финальное остаточное сложение.
Аргументы:
----------
x : torch.Tensor
Входной скрытый тензор формы [batch_size, seq_len, emb_size] — результат эмбеддинга токенов либо предыдущего слоя.
mask : torch.Tensor, optional
(Необязательно) Маска внимания для ограничения области self-attention (например, для автоперемешивания или causal-LLM-моделей).
use_cache : bool, по умолчанию True
Если True — сохраняет кэш ключей/значений attention для ускорения авторегрессии (инференса).
cache : list, optional
(Необязательно) Предварительно вычисленный кеш attention (для ускорения генерации длинного текста).
Возвращает:
-----------
Tuple[torch.Tensor, Any]:
- Первый элемент: скрытый тензор выхода слоя с той же формой, что вход (последовательный residual из attention и MoE-блока).
- Второй элемент: обновлённый кэш attention (если use_cache=True), иначе None.
Пример:
-------
>>> out, cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
>>> out.shape # [batch_size, seq_len, emb_size]
Примечания:
-----------
- Для autoregressive-генерации (GPT-like режимов) следует передавать mask и использовать use_cache=True.
- Реализация поддерживает произвольные батчи и длины последовательностей, в пределах max_seq_len слоя.
- Модуль MixtralDecoder обычно используется в виде стека (несколько подряд) внутри крупной LLM.
"""
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)

229
llm/src/llm/core/moe.py Normal file
View File

@@ -0,0 +1,229 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.swi_glu import SwiGLU
class MoE(nn.Module):
"""
MoE (Mixture of Experts) — слой «смеси экспертов» для современных трансформерных архитектур с разреженной активацией.
Назначение:
-----------
Класс реализует слой разреженного условного вычисления для увеличения capacity трансформеров без роста вычислительных затрат.
Для каждого токена из последовательности выбирается (с помощью роутера) наиболее подходящее подмножество экспертов (малых нейросетей).
Итоговый выход формируется как взвешенная сумма откликов экспертов, выбранных для данного токена.
Архитектурная схема:
---------------------
- Для каждого входного токена `x` роутер (обычно один Linear-слой) предсказывает skor, насколько каждый из `num_experts` релевантен.
- Для каждого токена выбираются top_k_experts с максимальными skor; только они обрабатывают этот токен.
- Каждый эксперт здесь представлен отдельным экземпляром блока `SwiGLU` (может быть любая небольшая feed-forward сеть).
- Выход каждого эксперта умножается на индивидуальный вес (softmax по skor) — агрегируется взвешенная сумма.
- Dropout применяется к итоговому выходу.
Математика (коротко):
---------------------
Пусть X ∈ R^{BxSxD} — вход,
E — число экспертов,
K — число активируемых экспертов на токен.
r(x) = softmax(W_r x) — роутинг-логиты, top-K берём индексы и веса.
Для каждого токена:
y_j = Expert_j(x)
y = sum_j(w_j * y_j), где j пробегает по выбранным экспертам
Output: Y ∈ R^{BxSxD}
Аргументы конструктора:
----------------------
emb_size : int
Размерность входных/выходных векторов (обычно совпадает с embedding модели).
num_experts : int
Общее число экспертов внутри слоя MoE.
top_k_experts : int
Сколько экспертов активировать и агрегировать на каждом токене (обычно 2-8).
dropout : float, по умолчанию 0.1
Dropout к выходу агрегатора.
Пример использования:
---------------------
>>> moe = MoE(emb_size=512, num_experts=8, top_k_experts=2, dropout=0.1)
>>> x = torch.randn(4, 16, 512)
>>> y = moe(x)
>>> y.shape # torch.Size([4, 16, 512])
Литература:
-----------
- Shazeer, N. et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
- Fedus, W., Zoph, B., & Shazeer, N. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”, 2021. https://arxiv.org/abs/2101.03961
- Mistral/Mixtral: https://mistral.ai/news/mixtral-of-experts/
"""
def __init__(
self,
emb_size: int,
num_experts: int,
top_k_experts: int,
dropout: float = 0.1,
):
"""
Конструктор слоя MoE (Mixture of Experts).
Позволяет создать слой, состоящий из набора экспертов (например, отдельных небольших feedforward-нейросетей) и роутера,
который будет для каждого токена определять наиболее релевантных экспертов.
Часть экспертов (top_k_experts) активируется для каждого токена, остальные — пропускаются.
Аргументы:
----------
emb_size : int
Размерность входных и выходных векторов (embedding size).
Определяет, над каким пространством признаков будет работать роутер и эксперты.
Например, если скрытый размер слоя трансформера 512, сюда нужно передать 512.
num_experts : int
Общее количество экспертов в слое MoE.
Чем больше экспертов — тем больше capacity у модели, но тем выше требования к RAM/VRAM при обучении.
Пример: 8, 16, 32, 64.
top_k_experts : int
Сколько экспертов одновременно будет обрабатывать каждый токен.
Обычно 28. Меньшее значение — выше разреженность, больше экономия вычислений.
dropout : float, по умолчанию 0.1
Вероятность зануления значений на выходе после агрегации откликов экспертов.
Используется для регуляризации (борьбы с переобучением).
Пример:
-------
>>> moe = MoE(emb_size=256, num_experts=8, top_k_experts=2, dropout=0.1)
>>> print(moe)
MoE( ... )
Теория:
-------
Слой строит:
- Линейный роутер (Linear(emb_size, num_experts)): выдает «важность» каждого эксперта для токена.
- Список из num_experts экспертов (в данной реализации — SwiGLU-блоки).
При каждом проходе для каждого токена выбираются top_k_experts наиболее релевантных экспертов,
их ответы агрегируются взвешенной суммой (softmax по роутерным логитам).
"""
super().__init__()
if top_k_experts > num_experts:
raise ValueError(f"top_k_experts ({top_k_experts}) должен быть меньше или равен num_experts ({num_experts})!")
self._num_experts = num_experts
self._top_k_experts = top_k_experts
self._router = nn.Linear(emb_size, num_experts)
self._experts = nn.ModuleList([SwiGLU(
emb_size=emb_size,
dropout=dropout,
) for _ in range(num_experts)])
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
"""
Прямой проход (forward) через слой MoE.
Для входной последовательности скрытых состояний (обычно из предыдущего слоя трансформера)
данный метод динамически выбирает для каждого токена топ-k наиболее релевантных экспертов с помощью роутера,
пропускает соответствующие токены через выбранных экспертов и агрегирует их результаты.
Математически:
--------------
1. Для каждого токена вычисляются логиты маршрутизатора (роутера):
router_logits = Linear(x) ∈ ^{batch, seq, num_experts}
2. Выбираются top_k экспертов (topk_indices) и соответствующие им softmax-веса (topk_weights).
3. Каждый эксперт обрабатывает только свой поднабор токенов.
4. Результат агрегируется — отклик эксперта умножается на вес, ответы суммируются для каждого токена.
5. На результат применяется dropout для регуляризации.
Аргументы:
----------
x : torch.Tensor
Трёхмерный входной тензор формы [batch_size, seq_length, emb_size],
где batch_size — размер батча, seq_length — длина последовательности, emb_size — размерность эмбеддинга.
Возвращает:
-----------
torch.Tensor :
Тензор той же формы [batch_size, seq_length, emb_size] — результат комбинирования выходов выбранных экспертов
с учетом softmax-весов маршрутизатора и dropout'а.
Пример:
-------
>>> y = moe(x)
>>> print(y.shape)
torch.Size([batch_size, seq_length, emb_size])
Примечание:
-----------
- Каждый токен чаще всего активирует только подмножество экспертов.
- Остальные эксперты вычислительно “спят”, что позволяет строить очень большие (по параметрам) модели с малым ростом затрат.
- Работа с распределением топ-к экспертов и агрегирование с весами реализовано автоматически.
"""
batch_size, seq_len, emb_size = x.shape
# 1. Пропускаем через роутер
router_logits = self._router(x) # [batch_size, seq_len, num_experts]
# 2. Отбираем топ-k экспертов для каждого токена
topk_logits, topk_indices = torch.topk(
router_logits,
k=self._top_k_experts,
dim=-1
) # topk_logits: [batch_size, seq_len, top_k]
# topk_indices: [batch_size, seq_len, top_k]
# 3. Получаем веса через softmax и нормируем
topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]
# 4. Создаём нулевой тензор для результата
output = torch.zeros_like(x) # [batch_size, seq_len, emb_size]
# 5. Проходим по всем экспертам
for expert_id in range(self._num_experts):
# Шаг 1: Создаём маску - где находится текущий эксперт в топ-k
expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]
# Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном
if not expert_mask.any():
continue # Эксперт никем не выбран, переходим к следующему
# Шаг 3: Находим токены, которые выбрали этого эксперта
# (хотя бы в одной из top_k позиций)
token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]
# Шаг 4: Отбираем токены из x
# Отбираем токены для этого эксперта
expert_input = x[token_mask]
# Пропускаем через эксперта
# Добавляем batch dimension для SwiGLU и затем убираем
expert_output = self._experts[expert_id](
expert_input.unsqueeze(0)
).squeeze(0)
# Получаем веса для этого эксперта
# Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)
# Но на практике каждый эксперт появляется максимум 1 раз в топ-k
# Находим веса: где expert_mask == True, берём соответствующий вес
weights_for_expert = torch.zeros(
batch_size, seq_len, device=x.device
)
# Для каждой позиции в топ-k
for k in range(self._top_k_experts):
mask_k = topk_indices[:, :, k] == expert_id
weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]
# Отбираем только веса для выбранных токенов
selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]
# Перемножьте выход эксперта на веса текущего эксперта.
weighted_output = selected_weights.unsqueeze(-1) * expert_output
# Помещаем результат на своё место в выходном тензоре
output[token_mask] += weighted_output
out = self._dropout(output)
return out

View File

@@ -1,122 +1,256 @@
from torch import nn
import torch
from .head_attention import HeadAttention
import torch.nn.functional as F
from .rope import RoPE
class MultiHeadAttention(nn.Module):
"""
Мультиголовый (многоголовый) механизм внимания — ключевой компонент любого Transformer.
Multi-Head Attention (Многоголовое внимание)
============================================
Научная суть:
- Модель параллельно агрегирует информацию через несколько подпространств (головы),
чтобы видеть разные связи в последовательности (разный контекст, локально/глобально).
- Каждый attention блок работает независимо, выход конкатенируется.
- Механизм предложен в статье "Attention is All You Need" (Vaswani et al., 2017).
Что такое Multi-Head Attention?
-------------------------------
Это ключевой компонент трансформеров, который позволяет "смотреть" на разные части предложения
одновременно с нескольких независимых ракурсов ("голов"). Всё, что делает Single-Head Attention — только гораздо мощнее и глубже!
Формула внимания для одной головы:
Attention(Q, K, V) = softmax(QK^T/sqrt(d_k))·V
Мультиголовый:
MultiHead(Q, K, V) = Concat([head_i])*W^O
Зачем это нужно?
----------------
- Модель может учиться одновременно учитывать и локальные, и глобальные взаимосвязи между токенами.
- Каждая attention head "ловит" свой собственный смысл/зависимости, и на выходе они объединяются.
- Это значительно улучшает понимание сложных зависимостей в тексте, особенно на длинных последовательностях.
Args:
num_heads (int): количество attention "голов"
emb_size (int): размерности входа и выхода
head_size (int): размер одной attention-головы (emb_size/num_heads)
max_seq_len (int): максимальная длина последовательности
rope (RoPE, optional): если задан, используется Rotary Positional Encoding
dropout (float): вероятность регуляризации
Как работает алгоритм? (основная схема)
---------------------------------------
1. Генерируются Q, K, V (query, key, value) — по отдельной проекции для каждой головы.
2. Для каждой головы: attention(Q, K, V) = softmax(Q·K^T / sqrt(d)) · V
3. Все головы "склеиваются" (concatenate) и прогоняются через общий финальный линейный слой.
Почему это работает?
--------------------
- Даёт трансформеру многомерное восприятие текста.
- Позволяет эффективно обучаться на задачах, где порядок и "дальние" связи важнее, чем простое соседство.
Что принимается на вход:
------------------------
- x: shape [batch, seq_len, embed_dim] — обычный batched-embed тензор.
- mask (опционально): shape [seq_len, seq_len] — маска для автогерерации или causal attention.
Какие параметры важны:
----------------------
- num_heads: сколько attention heads внутри (обычно 4, 8, 16...).
- embed_dim: исходная размерность входного тензора.
- head_size: размер одной attention-head (обычно embed_dim // num_heads).
- max_seq_len: максимальная длина последовательности для маски.
Что возвращает:
---------------
- output: shape [batch, seq_len, embed_dim] — результат применения всех attention heads.
- (опционально) cache: кэш для Q/K/V (нужно для генерации по одному токену).
Особенности реализации:
-----------------------
- Оптимизированно работает через матричные умножения (без python for циклов!).
- Включена поддержка causal attention (маска, предотвращающая «заглядывание в будущее»).
- Является ядром любого трансформера (и LLM!).
Пример использования:
>>> mha = MultiHeadAttention(num_heads=8, emb_size=512, head_size=64, max_seq_len=1024)
>>> x = torch.randn(2, 50, 512)
>>> out, cache = mha(x)
>>> print(out.shape)
---------------------
>>> attn = MultiHeadAttention(num_heads=8, embed_dim=256, head_size=32, max_seq_len=1024)
>>> x = torch.randn(2, 128, 256) # [batch, seq_len, embed_dim]
>>> context, _ = attn(x)
>>> print(context.shape) # torch.Size([2, 128, 256])
Где прочитать подробнее:
-------------------------
- Attention is All You Need (Vaswani et al, 2017): https://arxiv.org/abs/1706.03762
- Illustrated Transformer (blog): https://jalammar.github.io/illustrated-transformer/
"""
def __init__(self, num_heads: int, emb_size: int, head_size: int, max_seq_len: int, rope: RoPE = None, dropout: float = 0.1):
def __init__(
self,
num_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
rope: RoPE = None,
dropout: float = 0.1,
):
"""
Инициализация многоголового внимания.
Конструктор многоголового внимания (MultiHeadAttention).
Параметры:
num_heads (int): Количество голов внимания. Типичные значения: 4-16
emb_size (int): Размерность входных и выходных эмбеддингов
head_size (int): Размерность каждой головы внимания (обычно emb_size // num_heads)
max_seq_len (int): Максимальная длина последовательности
dropout (float): Вероятность dropout (по умолчанию 0.1)
Здесь создаются все параметры и внутренние слои для эффективного параллельного внимания (attention) сразу из нескольких "голов".
Контрольные значения:
- num_heads * head_size должно равняться emb_size
- head_size обычно выбирают 32-128
- max_seq_len зависит от задачи (512 для BERT, 2048 для GPT-3)
Аргументы:
----------
num_heads : int
Сколько attention-heads будет внутри слоя.
Каждая “голова” учится видеть уникальные зависимости в тексте. Обычно это 4, 8, 16 и т.п.
Чем больше голов — тем богаче контекст, но и больше памяти.
emb_size : int
Сколько float-значений в каждом входном векторе (размерность embedding).
Обычно это 256, 512, 768, 1024 и т.д.
head_size : int
Сколько компонент будет у каждой головы внимания.
Важно: num_heads * head_size должно ровно совпадать с emb_size!
Обычно head_size = emb_size // num_heads.
max_seq_len : int
Максимально допустимая длина последовательности для attention/маски/генерации.
Определяет размер буферов для causal mask.
rope : RoPE, по умолчанию None
Объект Rotary Positional Encoding (если хотите привнести продвинутое позиционное кодирование в attention).
Не обязателен, но нужен для современных LLM (Llama, Mistral и пр.).
dropout : float, по умолчанию 0.1
Величина dropout (регуляризации) — помогает борьбе с переобучением. Чем больше, тем сильнее регуляризация.
Внутри конструктора происходит:
-------------------------------
- Создаются три линейных слоя для Q, K, V (“где смотреть” и “что вытаскивать” в attention).
- Генерируется нижнетреугольная causal-маска (запрещает видеть будущее для автогерерации).
- Создаётся финальный линейный слой для склейки всех голов в одно пространство emb_size.
- Вводится dropout (случайное зануление, чтобы не было сильной зависимости внимания к отдельным "плейсам").
Пример:
-------
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
"""
super().__init__()
self._heads = nn.ModuleList([
HeadAttention(
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
rope=rope,
) for _ in range(num_heads)
])
self._num_heads = num_heads
self._head_size = head_size
self._max_seq_len = max_seq_len
self._rope = rope
self._q = nn.Linear(emb_size, num_heads * head_size)
self._k = nn.Linear(emb_size, num_heads * head_size)
self._v = nn.Linear(emb_size, num_heads * head_size)
# Создание causal маски
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer(
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
)
self._layer = nn.Linear(head_size * num_heads, emb_size)
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None):
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
use_cache: bool = True,
cache: list = None,
):
"""
Прямой проход (forward):
Для каждого токена оценивает "важность" остальных токенов сразу через несколько attention-блоков.
Основной шаг \"многоголового внимания\": находит взаимосвязи между токенами
в последовательности сразу из нескольких “ракурсов” (attention heads).
Подробное описание преобразований тензоров:
1. Входной тензор [batch_size, seq_len, emb_size] разделяется на N голов:
- Каждая голова получает тензор [batch_size, seq_len, head_size]
2. Каждая голова вычисляет attention:
- Вход: [batch_size, seq_len, head_size]
- Выход: [batch_size, seq_len, head_size]
3. Конкатенация результатов:
- Объединенный выход: [batch_size, seq_len, num_heads * head_size]
4. Линейная проекция:
- Выход: [batch_size, seq_len, emb_size]
5. Применение dropout
Что делает этот метод:
----------------------
- Для каждого токена сравнивает его с остальными во входной последовательности.
- Делает это одновременно через несколько attention heads (каждая head видит текст по-своему).
- Итоговое “внимание” — это взвешенная сумма других токенов (контекста) для каждого токена.
- Можно использовать кэш для генерации длинных последовательностей по одному токену (ускоряет инференс).
Args:
x (Tensor[float]): [batch, seq_len, emb_size] — вход
mask (Optional[Tensor[bool]]): маска позиции [seq_len, seq_len]
use_cache (bool): использовать ли key-value кэш (для генерации)
cache (list): предыдущие значения KV для ускорения
Аргументы:
----------
x : torch.Tensor
Входной тензор формы [batch, seq_len, emb_size].
Это ваши входные эмбеддинги (обычно после token + positional embedding).
mask : torch.Tensor, опционально
Матрица формы [seq_len, seq_len], задающая “разрешения” — кто может смотреть на кого (например, causal mask).
Если не указана — используется внутренняя маска (например, для autoregressive генерации).
use_cache : bool, по умолчанию True
Нужно ли использовать кэш для KV attention (важно для ускорения генерации по одному токену).
cache : list, опционально
Предыдущий кэш Key/Value — для генерации текста по частям.
Returns:
out (Tensor[float]): [batch, seq_len, emb_size] — результат MHA
kv_caches (list): списки новых KV-кэшей (если используется)
Возвращает:
-----------
- output: torch.Tensor формы [batch, seq_len, emb_size] — результат применения multi-head attention.
- kv_caches: список новых KV для кэширования при генерации (или None).
Типичный паттерн:
Вход: [batch, seq, emb] → N голов [batch, seq, head_size] →
→ concat [batch, seq, N*head_size] → проекция → dropout
Пример преобразований для emb_size=512, num_heads=8:
Вход: [4, 100, 512]
-> Каждая голова: [4, 100, 64]
-> После внимания: 8 x [4, 100, 64]
-> Конкатенация: [4, 100, 512]
-> Проекция: [4, 100, 512]
-> Dropout: [4, 100, 512]
Важно:
-------
- Shape входа всегда [batch, seq_len, emb_size], выход тот же.
- При seq_len > max_seq_len выбросит ошибку (безопасно для контроля переполнения буферов).
- При использовании use_cache=True кешируется только последние токены (актуально для LLM).
Пример:
>>> out, caches = mha(x)
>>> out.shape # [batch, seq_len, emb_size]
-------
>>> attn = MultiHeadAttention(num_heads=8, emb_size=256, head_size=32, max_seq_len=1024)
>>> x = torch.randn(2, 100, 256)
>>> y, kv_cache = attn(x)
>>> print(y.shape) # torch.Size([2, 100, 256])
"""
# 1. Вычисляем attention для каждой головы
attention_results = []
for i, head in enumerate(self._heads):
head_cache = cache[i] if cache is not None else None
result = head(x, use_cache=use_cache, cache=head_cache)
attention_results.append(result)
batch_size, seq_len, emb_size = x.shape
outputs, caches = zip(*attention_results)
attention_outputs = list(outputs)
kv_caches = list(caches)
if seq_len > self._max_seq_len:
raise ValueError(
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
)
# 2. Объединяем результаты всех голов
concatenated_attention = torch.cat(attention_outputs, dim=-1)
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
# Шаг 2: Изменение формы для multi-head
# [batch_size, seq_len, num_heads * head_size]
# -> [batch_size, seq_len, num_heads, head_size]
q = q.reshape(batch_size, seq_len, self._num_heads, self._head_size)
k = k.reshape(batch_size, seq_len, self._num_heads, self._head_size)
v = v.reshape(batch_size, seq_len, self._num_heads, self._head_size)
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
start_pos = 0
if cache is not None:
k_cache, v_cache = cache
cache_len = k_cache.shape[2]
start_pos = cache_len
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
if self._rope is not None:
# ✅ Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q, start_pos=start_pos) # [B, T, hs]
k = self._rope(k, start_pos=start_pos) # [B, T, hs]
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
# 5. Кэширование (для autoregressive generation)
if cache is not None:
k_cache, v_cache = cache
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
v = torch.cat([v_cache, v], dim=2)
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
# И разделить все значения в матрице внимания на корень из head_size.
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
if cache is None:
scores = scores.masked_fill(
~self._tril_mask[:seq_len, :seq_len], float("-inf")
)
# Применить к матрице внимания (построчно) функцию Softmax.
weights = F.softmax(scores, dim=-1)
# Перемножим матрицу внимания и матрицу значения.
x_out = weights @ v # [B, T, hs]
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
# Transpose обратно и concatenate heads
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
x_out = x_out.contiguous() # Важно для reshape!
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
#concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_heads * self._head_size)
# Пропустите получившийся тензор через последний линейный слой.
# 3. Проецируем в пространство эмбеддингов
projected_output = self._layer(concatenated_attention)
@@ -124,6 +258,6 @@ class MultiHeadAttention(nn.Module):
final_output = self._dropout(projected_output)
if use_cache is True:
return (final_output, kv_caches)
return (final_output, (k, v))
else:
return (final_output, None)

View File

@@ -0,0 +1,252 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.rope import RoPE
class MultiQueryAttention(nn.Module):
"""
Multi-Query Attention (MQA) — быстрый и экономичный вариант self-attention для LLM.
Назначение:
-----------
Класс реализует механизм внимания (self-attention), в котором для всех Query-голов используются одни и те же Key и Value.
В классическом MultiHeadAttention (MHA) на каждый Query используется свой Key/Value. В MQA набор Key/Value общий для всех голов,
что снижает требования к памяти и ускоряет работу, что особенно важно для больших LLM на inference.
Теоретическое преимущество:
--------------------------
- Существенно экономит память на матрицы Key и Value: количество KV-голов обычно в 48 раз меньше, чем число Query-голов.
- Позволяет достигать скорости почти обычной MHA при минимальной потере точности (см. Llama, Mistral).
- Является стандартом де-факто для deployment и inference современных LLM.
Архитектурная схема:
--------------------
- Для каждого токена во входе вычисляются Q_h (отдельные для каждой Query-головы), но K и V — общие для всех.
- Attention внутри каждой головы формируется через матричный продукт соответствующей Q_h и общего K.
- Выходные вектора голов конкатенируются и проецируются обратно в emb_size.
Формулы:
--------
Q = Wq·x, K = Wk·x, V = Wv·x
(Wq — отдельные для всех Query, Wk/Wv — общие для всех голов)
Attention_h(x) = softmax(Q_h·K^T / sqrt(d_k))·V
Output = Concat_h([Attention_h(x)])·W_o
Аргументы конструктора:
-----------------------
emb_size : int
Размерность скрытого пространства (hidden size, embedding dim).
num_heads : int
Число Query-голов (обычно 832 в LLM).
kv_heads : int
Число Key/Value-голов (обычно 1, 2, 4, 8).
head_size : int, optional
Размерность одной головы (обычно emb_size // num_heads).
dropout : float, optional
Вероятность Dropout для регуляризации внимания.
Пример использования:
---------------------
>>> mqa = MultiQueryAttention(emb_size=512, num_heads=8, kv_heads=1)
>>> x = torch.randn(2, 16, 512)
>>> mask = torch.ones(2, 16, 16)
>>> out = mqa(x, mask)
>>> print(out.shape) # torch.Size([2, 16, 512])
Литература и статьи:
--------------------
- Shazeer, N., “Fast Transformer Decoding: One Write-Head Is All You Need” (MQA): https://arxiv.org/abs/1911.02150
- Llama: https://arxiv.org/abs/2302.13971
- Mistral: https://arxiv.org/abs/2310.06825
- PaLM/PaLM2, Mixtral, ChatGLM: практическое описание MQA.
"""
def __init__(
self,
num_q_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
rope: RoPE = None,
dropout: float = 0.1,
):
"""
Конструктор MultiQueryAttention.
Инициализирует все слои и буферы для реализации Multi-Query Attention с общими K/V-головами и индивидуальными Q-головами.
Позволяет существенно ускорять инференс и экономить память при работе с большими языковыми моделями.
Аргументы:
----------
num_q_heads : int
Число query-голов (обычно совпадает с количеством attention heads в модели).
Определяет количество параллельных subspace для запроса.
emb_size : int
Размер скрытого пространства embedding (input/output размерность attention слоя).
head_size : int
Размерность одной attention-головы.
Обычно emb_size // num_q_heads.
max_seq_len : int
Максимально поддерживаемая длина последовательности (нужна для построения треугольной маски causal attention).
rope : RoPE, optional
Модуль для rotary positional encoding (позиционный энкодер, улучшает обобщающую способность attention).
Если None, positional encoding не применяется.
dropout : float, по умолчанию 0.1
Вероятность dropout для выходного слоя attention (регуляризация).
Внутри:
-------
- Насчитывает отдельные весовые слои для Q, общие для всех голов K/V.
- Строит causal маску для автогрессивной генерации.
- (Опционально) использует RoPE для позиционного кодирования.
- Dropout применяется после финального projection.
Пример:
-------
>>> mqa = MultiQueryAttention(emb_size=256, num_q_heads=8, head_size=32, max_seq_len=2048, rope=None, dropout=0.1)
"""
super().__init__()
self._num_q_heads = num_q_heads
self._head_size = head_size
self._max_seq_len = max_seq_len
self._rope = rope
self._q = nn.Linear(emb_size, num_q_heads * head_size)
self._k = nn.Linear(emb_size, head_size)
self._v = nn.Linear(emb_size, head_size)
# Создание causal маски
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer(
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
)
self._layer = nn.Linear(num_q_heads * head_size, emb_size)
self._dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
use_cache: bool = True,
cache: list = None,
):
"""
Прямой проход (forward) через слой MultiQueryAttention.
Реализует multi-query self-attention для входных последовательностей с оптимизацией памяти за счёт общих K/V-голов для всех Query.
Поддерживает работу с rotary positional encoding (RoPE), каузальной маской и кэшированием для ускорения генерации.
Аргументы:
----------
x : torch.Tensor
Входной тензор формы [batch_size, seq_len, emb_size] — скрытые состояния после предыдущего слоя или эмбеддинга.
mask : torch.Tensor, optional
Необязательная маска внимания (например, для padding или custom-маскировки). По умолчанию используется встроенная causal mask.
use_cache : bool, по умолчанию True
Если True, возвращает кэш ключей/значений (для autoregressive inference/generation).
cache : list, optional
(K_cache, V_cache) — предварительный кэш KV (для ускоренного инференса). Если None, кэш не используется/создаётся заново.
Возвращает:
-----------
если use_cache == True:
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
- attention_out: [batch_size, seq_len, emb_size] — результат attention после проекции и dropout.
- (K, V): кэшированные ключи и значения (использовать для последующих forward'ов в autoregressive генерации)
если use_cache == False:
Tuple[torch.Tensor, None]
Математические шаги:
--------------------
1. Q = Wq·x; K = Wk·x; V = Wv·x # Q: индивидуальные для каждой головы, K/V — общие
2. [optional] Rotary positional encoding применяется к Q и K
3. (optional) concat c k/v cache (for autoregressive inference)
4. attention_scores = softmax(Q·K^T / sqrt(head_size), mask)
5. attention_out = attention_scores·V
6. heads сливаются и проецируются в emb_size; применяется dropout.
Пример:
-------
>>> out, cache = mqa(x, mask=attn_mask, use_cache=True, cache=prev_cache)
>>> print(out.shape) # torch.Size([batch_size, seq_len, emb_size])
Примечания:
-----------
- Для каузального режима используется треугольная маска (по умолчанию).
- Для генерации текста с cache передавайте кэш от предыдущих токенов — это ускоряет autoregressive inference.
- Внимание! Тензоры внутри cache должны иметь форму [batch, heads, seq_len, head_size].
"""
batch_size, seq_len, emb_size = x.shape
if seq_len > self._max_seq_len:
raise ValueError(
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
)
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
# Шаг 2: Изменение формы для multi-head
# [batch_size, seq_len, num_heads * head_size]
# -> [batch_size, seq_len, num_heads, head_size]
q = q.reshape(batch_size, seq_len, self._num_q_heads, self._head_size)
k = k.reshape(batch_size, seq_len, 1, self._head_size)
v = v.reshape(batch_size, seq_len, 1, self._head_size)
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
if self._rope is not None:
# Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q) # [B, T, hs]
k = self._rope(k) # [B, T, hs]
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
# 5. Кэширование (для autoregressive generation)
if cache is not None:
k_cache, v_cache = cache
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
v = torch.cat([v_cache, v], dim=2)
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
# И разделить все значения в матрице внимания на корень из head_size.
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
if cache is None:
scores = scores.masked_fill(
~self._tril_mask[:seq_len, :seq_len], float("-inf")
)
# Применить к матрице внимания (построчно) функцию Softmax.
weights = F.softmax(scores, dim=-1)
# Перемножим матрицу внимания и матрицу значения.
x_out = weights @ v # [B, T, hs]
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
# Transpose обратно и concatenate heads
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
x_out = x_out.contiguous() # Важно для reshape!
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_q_heads * self._head_size)
# Пропустите получившийся тензор через последний линейный слой.
# 3. Проецируем в пространство эмбеддингов
projected_output = self._layer(concatenated_attention)
# 4. Применяем dropout для регуляризации
final_output = self._dropout(projected_output)
if use_cache is True:
return (final_output, (k, v))
else:
return (final_output, None)

View File

@@ -1,66 +1,105 @@
import torch
from torch import nn, Tensor
class PositionalEmbeddings(nn.Module):
"""
Обучаемые позиционные эмбеддинги (learnable positional embeddings).
PositionalEmbeddings — классические позиционные эмбеддинги для трансформеров (absolute sinusoidal or learned).
Позиционные эмбеддинги используются в нейросетях для передачи информации
о позиции элементов в последовательности (например, в Transformer).
Назначение:
-----------
- Добавляет или конкатенирует форму позиционной информации к каждому входному токену (since Transformer cannot distinguish positions otherwise).
- Используется во всех \"ранних\" трансформерах (GPT, BERT, T5), чаще всего в виде learnable или синусоидальных embeddings.
Научная суть:
- Трансформеры не используют рекуррентность, а значит сами по себе не различают порядок слов.
- Позиционные эмбеддинги добавляются к токеновым, чтобы сеть понимала, в каком месте последовательности находится каждый токен.
- Обычно реализуются как отдельная матрица (nn.Embedding), которая обучается вместе с моделью (это learnable вариант, как в GPT и BERT).
Архитектурные варианты:
-----------------------
- Learnable positional embeddings (как в GPT-2): обычный nn.Embedding инициализируется случайно, и веса учатся вместе с моделью.
- Sinusoidal positional encoding (как в оригинальном Transformer): не имеет параметров, а создаётся по заданной формуле sin/cos(ω*x).
Args:
max_seq_len (int): максимальная длина последовательности
emb_size (int): размер вектора позиции
Принцип работы:
---------------
- Для каждой позиции t заполняется вектор emb_size длиной по формуле (или выбирается из weight matrix).
- Эти вектора можно либо складывать с токеновыми эмбеддингами, либо конкатенировать.
- Позволяет attention-механизму \"понимать\" порядок токенов/слов в последовательности.
Пример использования:
>>> pos_encoder = PositionalEmbeddings(max_seq_len=100, emb_size=256)
>>> # Получить эмбеддинги для последовательности из 10 элементов
>>> embeddings = pos_encoder(10) # Tensor shape: [10, 256]
>>> # Использование в модели
>>> class MyModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.pos_emb = PositionalEmbeddings(100, 256)
... def forward(self, x):
... pos = self.pos_emb(x.size(1))
... return x + pos # Добавляем позиционную информацию
Формулы (Or: Vaswani et al., 2017):
------------------------------------
PE(pos, 2i) = sin(pos / 10000^{2i/d})
PE(pos, 2i+1) = cos(pos / 10000^{2i/d})
где d = emb_size, pos = позиция (int), i = индекс пары компонент.
Аргументы конструктора:
-----------------------
max_seq_len: int — максимально поддерживаемая длина последовательности
emb_size: int — размер возвращаемого positional vector для каждой позиции
(иногда выбирается вариант — learnable или фиксация через sin/cos)
Пример:
-------
>>> pos = PositionalEmbeddings(max_seq_len=1024, emb_size=256)
>>> p = pos(32) # Получить positional embeddings для 32 позиций
>>> p.shape # torch.Size([32, 256])
>>> token_emb = ... # [batch, seq_len, emb_size]
>>> encoded = token_emb + p.unsqueeze(0) # Broadcast add
References:
-----------
- Vaswani et al., \"Attention is All You Need\", 2017: https://arxiv.org/abs/1706.03762
- GPT-2 implementation: https://github.com/openai/gpt-2
- Почему positional encoding важен: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
"""
def __init__(self, max_seq_len: int, emb_size: int):
"""
Инициализация позиционного энкодера.
Аргументы:
----------
max_seq_len : int
Максимальная длина последовательности (builds buffer for sin/cos or embedding)
emb_size : int
Длина позиционного вектора
Внутри:
-------
- Если используется learned embedding: создаётся nn.Embedding (можно легко менять в будущем).
- Если fixed (sin/cos): вычисляется и хранится буфер (max_seq_len, emb_size).
"""
super().__init__()
self.max_seq_len = max_seq_len
self.emb_size = emb_size
self.embedding = nn.Embedding(
num_embeddings=max_seq_len,
embedding_dim=emb_size
num_embeddings=max_seq_len, embedding_dim=emb_size
)
def forward(self, seq_len: int, start_pos: int = 0) -> Tensor:
"""
Возвращает позиционные эмбеддинги для заданной длины последовательности.
Получить positional embeddings для последовательности длиной seq_len.
Args:
seq_len (int): Длина последовательности (1 <= seq_len <= max_seq_len)
Аргументы:
----------
seq_len : int
Сколько позиций сгенерировать (обычно == входная длина x)
start_pos : int, по умолчанию 0
Возможность выдать positional embeddings \"с середины\" (для autoregressive генерации)
Returns:
Tensor: Тензор позиционных эмбеддингов формы [seq_len, emb_size]
Raises:
IndexError: Если seq_len выходит за допустимые границы
Возвращает:
-----------
torch.Tensor — positional embeddings формы [seq_len, emb_size]
Пример:
>>> pos_encoder = PositionalEmbeddings(100, 64)
>>> emb = pos_encoder(10) # Тензор 10x64
-------
>>> pos = PositionalEmbeddings(512, 128)
>>> p = pos(10) # [10, 128]
"""
if seq_len < 1 or seq_len > self.max_seq_len:
raise IndexError(f"Длина {seq_len} должна быть от 1 до {self.max_seq_len}")
if start_pos == 0:
positions = torch.arange(seq_len, device=self.embedding.weight.device)
else:
positions = torch.arange(start=start_pos, end=start_pos + seq_len, device=self.embedding.weight.device)
positions = torch.arange(
start=start_pos,
end=start_pos + seq_len,
device=self.embedding.weight.device,
)
return self.embedding(positions)

View File

@@ -24,35 +24,63 @@ from typing import Optional
class RMSNorm(nn.Module):
"""
RMS Normalization (Root Mean Square Layer Normalization).
RMSNorm (Root Mean Square Layer Normalization) — простая и эффективная альтернатива LayerNorm.
Нормализует входные данные по последнему измерению используя среднеквадратичное
значение вместо среднего, как в стандартном LayerNorm.
Назначение:
-----------
- Нормализует входной тензор по последнему измерению только с помощью RMS (root mean square), без вычитания среднего.
- Используется в LLaMA, PaLM и других крупных языковых моделях для лучшей стабильности и ускорения обучения.
- В отличие от LayerNorm, не центрирует значения, что особенно полезно для автогерессивных трансформеров с residual-связями.
Научная суть:
- Упрощенный вариант LayerNorm без вычисления среднего, только деление на rms.
- Лучшая численная стабильность на больших моделях, меньше вычислений.
- Применяется в LLaMA, PaLM и др.
Мотивация и математика:
-----------------------
- Формула для одного слоя и вектора x:
rms = sqrt( mean( x ** 2 ) + eps )
out = w * ( x / rms )
где w — learnable scale, eps — небольшая константа для численной устойчивости.
- Нет смещения/вычитания среднего — сигнал сохраняет абсолютные значения, меньше “искажает” автоагрегатные значения на накопленных резидуалах.
Формула:
RMSNorm(x) = (x / sqrt(mean(x²) + eps)) * w (w — обучаемый вектор)
Аргументы конструктора:
-----------------------
dim : int
Размер последнего нормализуемого измерения (обычно совпадает с размером embedding/final head).
eps : float, default=1e-6
Малое значение для устойчивости (additive epsilon).
Args:
dim (int): размер последнего измерения (обычно emb_size)
eps (float): для численной устойчивости
Особенности:
------------
- Нет батч-нормализации, нет зависимости от размера батча.
- Отлично подходит для больших моделей и автогерессии — меньше шуму от residual.
Пример использования:
---------------------
>>> norm = RMSNorm(emb_size=256)
>>> x = torch.randn(4, 10, 256)
>>> out = norm(x) # возвращает tensor той же формы
References:
-----------
- Zhang & Sennrich, "Root Mean Square Layer Normalization", 2019: https://arxiv.org/abs/1910.07467
- Применение в LLaMA: https://arxiv.org/abs/2302.13971
- HuggingFace implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
Пример:
>>> norm = RMSNorm(emb_size)
>>> out = norm(x)
"""
def __init__(self, dim: int, eps: float = 1e-6):
"""
Инициализация RMSNorm слоя.
Инициализация RMSNorm.
Args:
dim: Размерность нормализуемого измерения
eps: Малое значение для численной стабильности (по умолчанию 1e-6)
-----
dim : int
Последнее нормализуемое измерение (обычно размерность embedding или hidden).
eps : float
Малое значение для устойчивости (по умолчанию 1e-6).
Внутри:
-------
- Создаётся обучаемый scale weight w для каждой компоненты dim.
- Сохраняется параметр eps для добавления к RMS.
"""
super().__init__()
self._eps = eps
@@ -60,16 +88,28 @@ class RMSNorm(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Прямой проход через RMSNorm слой.
Прямой проход через RMSNorm.
Args:
x: Входной тензор формы [..., dim]
-----
x : torch.Tensor
Входной тензор любого shape с последней размерностью dim.
Returns:
Нормализованный тензор той же формы, что и входной
--------
torch.Tensor — тот же shape, что и вход x, но нормализованный по RMS на последнем измерении.
Алгоритм:
---------
- Вычислить rms = sqrt( mean( x**2, dim=-1, keepdim=True ) + eps )
- Поделить x на rms
- Помасштабировать обучаемым весом w
Пример:
-------
>>> norm = RMSNorm(256)
>>> out = norm(torch.randn(2, 10, 256))
Формула:
output = w * (x / sqrt(mean(x²) + eps))
"""
# Вычисление RMS (Root Mean Square) по последнему измерению
rms = (x.pow(2).mean(-1, keepdim=True) + self._eps) ** 0.5
@@ -80,4 +120,4 @@ class RMSNorm(nn.Module):
def extra_repr(self) -> str:
"""Строковое представление для отладки."""
return f'dim={self._w.shape[0]}, eps={self._eps}'
return f"dim={self._w.shape[0]}, eps={self._eps}"

View File

@@ -1,21 +1,51 @@
"""
Rotary Positional Embeddings (RoPE) - ротационные позиционные эмбеддинги.
Rotary Positional Embeddings (RoPE)
===================================
Реализация ротационного позиционного кодирования, которое кодирует позиционную
информацию через вращение векторов запросов и ключей в комплексном пространстве.
Что такое RoPE?
----------------
RoPE — это способ "вписать" информацию о позиции токенов в скрытые вектора модели трансформера.
Вместо простого сложения с абсолютным positional embedding, RoPE использует вращения векторов (как поворот стрелки на круге) внутри каждого attention head. Каждый элемент пары (вектор четного и нечетного индекса) поворачивается на угол, зависящий от позиции токена.
Научная статья: "RoFormer: Enhanced Transformer with Rotary Position Embedding"
https://arxiv.org/abs/2104.09864
Зачем это?
-----------
- RoPE реализует **относительное позиционное кодирование**: модель может сравнивать "расстояния" между токенами, а не просто помнить положение.
- Такое кодирование **улучшает генерацию длинных последовательностей** и перенос модели на тексты большей длины, чем были в обучении.
- Форма векторов и длина (норма) НЕ искажаются.
Математическая основа:
Для позиции m и измерения i:
θ_i = base^(-2i/d)
q'_m = q_m * cos(mθ_i) + rotate(q_m) * sin(mθ_i)
Как это работает? (главная формула)
-------------------------------------
Для каждой позиции m и пары компонент (2i, 2i+1) внутри head применяются:
θ_i = base^(-2i / d)
q'_{m,2i} = q_{m,2i} * cos(m * θ_i) - q_{m,2i+1} * sin(m * θ_i)
q'_{m,2i+1} = q_{m,2i+1} * cos(m * θ_i) + q_{m,2i} * sin(m * θ_i)
где d — размерность "головы" attention (head_size), base обычно 10_000.
То есть, берём каждый "вектор" (в рамках head), делим на четные/нечетные части и поворачиваем их на уникальный угол, связанный с позицией/частотой.
Архитектурные детали:
---------------------
- Ваш тензор должен быть строго 4-мерным: [batch, num_heads, seq_len, head_size].
- Размер head_size должен быть чётным!
- RoPE применяется отдельно к **Q** и **K** в механизме внимания (но не к V).
Где об этом читать:
-------------------
- RoFormer: Enhanced Transformer with Rotary Position Embedding
https://arxiv.org/abs/2104.09864
- Llama: Open and Efficient Foundation Language Models
https://arxiv.org/abs/2302.13971
- Визуализация позиционных кодировок:
https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
Пример использования:
---------------------
>>> rope = RoPE(head_size=64, max_seq_len=2048)
>>> x = torch.randn(2, 8, 128, 64) # [batch, num_heads, seq_len, head_size]
>>> x_enc = rope(x) # здесь вектор x обогатится позиционной информацией
Преимущества:
- Относительное позиционное кодирование
- Лучшая экстраполяция на длинные последовательности
- Сохранение нормы векторов
"""
import torch
@@ -25,32 +55,72 @@ from typing import Optional
class RoPE(nn.Module):
"""
Rotary Positional Embeddings (RoPE) для механизма внимания.
Реализация RoPE (Rotary Positional Embeddings) для self-attention в трансформерах.
Кодирует позиционную информацию через вращение векторов запросов и ключей
в многомерном пространстве с использованием синусов и косинусов.
Этот слой добавляет позиционную информацию к векторам внимания (Q, K) —
не с помощью простого сложения с positional embedding, а с помощью математического
вращения (как если бы вы крутили стрелку на круге) для каждой пары компонент
(even/odd) в каждом attention head.
Args:
head_size: Размерность головы внимания (должен быть четным)
max_seq_len: Максимальная длина последовательности
base: Базовое значение для вычисления частот (по умолчанию 10000)
Формула (для каждого токена и каждой пары компонент внутри head):
θ_i = base^(-2i / d)
out_{m,2i} = x_{m,2i} * cos(m * θ_i) - x_{m,2i+1} * sin(m * θ_i)
out_{m,2i+1} = x_{m,2i+1} * cos(m * θ_i) + x_{m,2i} * sin(m * θ_i)
где d — head_size, base обычно 10_000, степень i по head axis.
Attributes:
cos_matrix: Буферизованная матрица косинусов формы [max_seq_len, head_size//2]
sin_matrix: Буферизованная матрица синусов формы [max_seq_len, head_size//2]
Какие входы принимает:
----------------------
- x: обязательно размерности [batch, num_heads, seq_len, head_size]!
- head_size (размер внимания) должен быть чётным.
- start_pos: опционально, позволяет сдвигать позиционный offset для генерации с кэшем.
Что возвращает:
---------------
- Тот же тензор (x), только со встроенной позиционной информацией (“повёрнутый” RoPE-кодировкой).
- Форма и тип выходного тензора не меняются.
Где используется:
-----------------
- В любых современных LLM (Llama, Mistral, GPT-NeoX и др.) для повышения устойчивости и generalization transformer's attention.
Пример использования:
---------------------
>>> rope = RoPE(head_size=64, max_seq_len=2048)
>>> x = torch.randn(2, 8, 128, 64) # (batch, num_heads, seq_len, head_size)
>>> x_encoded = rope(x)
Подробнее про математику и примеры с визуализацией:
---------------------------------------------------
- RoFormer: https://arxiv.org/abs/2104.09864
- Llama: https://arxiv.org/abs/2302.13971
- Демонстрация наглядно: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
"""
def __init__(self, head_size: int, max_seq_len: int, base: int = 10_000):
"""
Инициализация RoPE эмбеддингов.
Инициализация объекта RoPE — настраивает и предвычисляет все необходимые
параметры для ротационного позиционного кодирования.
Args:
head_size: Размерность головы внимания (должен быть четным)
max_seq_len: Максимальная поддерживаемая длина последовательности
base: Базовое значение для вычисления частот (типично 10000)
Аргументы:
----------
head_size : int
Размер одного attention head (последнего измерения вектора) — сколько компонент
(float-значений) отвечает за одну "голову". Должен быть ЧЁТНЫМ числом, иначе RoPE не применим.
Обычно head_size = embed_dim // num_heads.
max_seq_len : int
Максимальная длина последовательности, которую RoPE сможет обработать.
Если ваш текст длиннее этого числа — будет ошибка! Например, для GPT2 обычно 1024, у LLaMA — до 4096.
Это число определяет размер внутренних буферов cos/sin.
base : int, по умолчанию 10_000
База для вычисления частот вращения (θ_i) для каждой компоненты.
В оригинальных статьях почти всегда используют base=10000.
Менять этот параметр не нужно, если вы не исследуете математические детали.
Raises:
AssertionError: Если head_size не четный
Что происходит внутри:
----------------------
- Проверяется чётность head_size.
- Для каждого возможного положения в пределах max_seq_len и каждой пары component высчитываются уникальные cos/sin значения (матрицы частот).
- Эти матрицы используются далее для быстрого наложения позиционного "вращения" токенов внутри attention.
"""
super().__init__()
assert head_size % 2 == 0, "head_size должен быть четным"
@@ -65,33 +135,56 @@ class RoPE(nn.Module):
freq_matrix = positions.unsqueeze(1) * freqs.unsqueeze(0)
# Предвычисление матриц косинусов и синусов
self.register_buffer('cos_matrix', torch.cos(freq_matrix))
self.register_buffer('sin_matrix', torch.sin(freq_matrix))
self.register_buffer("cos_matrix", torch.cos(freq_matrix))
self.register_buffer("sin_matrix", torch.sin(freq_matrix))
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(self, x: torch.Tensor, start_pos: int = 0) -> torch.Tensor:
"""
Применение ротационного позиционного кодирования к входному тензору.
Применяет ротационное позиционное кодирование (RoPE) к входному тензору.
Args:
x: Входной тензор формы [batch_size, seq_len, head_size]
Что делает эта функция:
-----------------------
Для каждого токена в последовательности внутри каждого attention head
"поворачивает" его вектор в подпространстве (even/odd пар) на свой уникальный угол,
зависящий от позиции токена. Это позволяет attention "понимать расстояния" между токенами.
Returns:
Тензор с примененным RoPE формы [batch_size, seq_len, head_size]
Аргументы:
----------
x : torch.Tensor
Входной тензор строго формы [batch, num_heads, seq_len, head_size].
Это обычно либо Q, либо K из механизма внимания.
start_pos : int, по умолчанию 0
Сдвиг начала позиции (нужно при генерации с кэшем, почти всегда оставить 0 если не пишете автогенератор).
Алгоритм:
1. Разделение векторов на четные и нечетные компоненты
2. Применение вращения через синусы и косинусы
3. Объединение компонент обратно
Возвращает:
-----------
torch.Tensor с теми же формой и типом, что и x, но уже с наложенным позиционным кодированием.
Важно:
-------
- Если передан тензор не 4D, будет выброшено исключение!
- Не изменяет значения "на месте", всегда возвращает новый тензор.
Пример:
-------
>>> rope = RoPE(head_size=64, max_seq_len=1024)
>>> q = torch.randn(2, 8, 32, 64) # batch, num_heads, seq_len, head_size
>>> q_rope = rope(q)
"""
seq_len = x.size(1)
assert x.ndim == 4, "RoPE поддерживает только 4D-вход [batch, num_heads, seq_len, head_size]"
batch_size, num_heads, seq_len, head_size = x.shape
# Берем нужную часть матриц и приводим к типу x
cos = self.cos_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
sin = self.sin_matrix[:seq_len].to(x.dtype) # [seq_len, head_size//2]
cos = self.cos_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
sin = self.sin_matrix[start_pos:start_pos+seq_len].to(x.dtype) # [seq_len, head_size//2]
# Разделяем на четные и нечетные компоненты
x_even = x[:, :, 0::2] # [batch_size, seq_len, head_size//2]
x_odd = x[:, :, 1::2] # [batch_size, seq_len, head_size//2]
# Явное изменение формы для broadcasting
cos = cos.reshape(1, 1, seq_len, head_size // 2)
sin = sin.reshape(1, 1, seq_len, head_size // 2)
# Разделяем на четные и нечетные компоненты по ПОСЛЕДНЕМУ измерению
x_even = x[..., 0::2] # [batch_size, num_heads, seq_len, head_size//2]
x_odd = x[..., 1::2] # [batch_size, num_heads, seq_len, head_size//2]
# Применяем поворот: q' = q * cos(mθ) + rotate(q) * sin(mθ)
x_rotated_even = x_even * cos - x_odd * sin

View File

@@ -1,19 +1,70 @@
import torch
from torch import nn
class SiLU(nn.Module):
"""
SiLU (Swish) — современная активационная функция для нейросетей.
SiLU (Sigmoid Linear Unit, также известная как Swish) — современная функция активации для нейросетей и LLM.
Научная суть:
- Формула: $SiLU(x) = x * \sigm(x)$, где $\sigm(x)$ — сигмоида.
- Более гладкая альтернатива ReLU, улучшает поток градиентов в глубоких сетях.
- Используется во многих «state-of-the-art» архитектурах (SwiGLU, PaLM, LLaMA).
- Также известна как Swish (Ramachandran et al, 2017).
Пример:
>>> act = SiLU()
>>> x = torch.tensor([-1.0, 0.0, 1.0])
>>> print(act(x))
Назначение:
-----------
- Формирует плавную нелинейную активацию: SiLU(x) = x * sigmoid(x).
- Активно используется во всех новых архитектурах для больших языковых моделей (PaLM, LLaMA, Mistral, GPT-4 и др.).
- Дает лучший поток градиентов по сравнению с ReLU, SELU, GELU в глубоких слоях — позволяет делать сети больше и глубже.
Мотивация и свойства:
---------------------
- SiLU объединяет свойства identity (для больших x) и ReLU (для отрицательных x, где есть затухание), но более плавно.
- Позволяет проходить отрицательным значениям, а не "обрубает" как ReLU.
- Better for optimization and training dynamics in deep LLMs, приводит к более богатым аппроксимациям.
Математическая формула:
-----------------------
SiLU(x) = x * sigmoid(x)
где sigmoid(x) = 1 / (1 + exp(-x))
Сравнение с другими активациями:
--------------------------------
- ReLU(x): max(0, x) — простая отсечка
- GELU(x): плавная вероятностная активация (используется в BERT/GPT-2)
- SiLU(x): плавная альтернатива, часто лучше в современных LLM
- Swish (Ramachandran et al., 2017) = SiLU
Args:
-----
Нет learnable параметров, чисто функциональная активация.
Пример использования:
---------------------
>>> silu = SiLU()
>>> x = torch.tensor([-2.0, 0.0, 2.0])
>>> print(silu(x)) # тензор с элементами [-0.2384, 0.0, 1.7616] (примерно)
References:
-----------
- Ramachandran et al., "Searching for Activation Functions", 2017: https://arxiv.org/abs/1710.05941
- LLaMA: https://arxiv.org/abs/2302.13971
- Swish в TensorFlow: https://arxiv.org/abs/1710.05941
- Сравнение всех актив. функций: https://paperswithcode.com/method/silu
"""
def forward(self, x: torch.Tensor):
"""
Применяет SiLU активацию ко всем компонентам тензора (x * sigmoid(x)).
Args:
-----
x : torch.Tensor
Входной тензор любой формы.
Returns:
--------
torch.Tensor — тензор той же формы, каждый элемент преобразован по формуле SiLU(x).
Пример:
-------
>>> silu = SiLU()
>>> x = torch.linspace(-3, 3, 7)
>>> y = silu(x)
"""
return torch.sigmoid(x) * x

View File

@@ -24,31 +24,55 @@ from .silu import SiLU
class SwiGLU(nn.Module):
"""
SwiGLU (Swish-Gated Linear Unit) — современная нелинейность для архитектур LLM (LLaMA, PaLM).
SwiGLU (Swish-Gated Linear Unit) — эффективная feed-forward нелинейность для трансформеров (LLAMA, PaLM, Mistral).
Реализация SwiGLU активационной функции.
Назначение:
-----------
- Улучшает классический блок FeedForward (FFN) в трансформерах за счёт \"gating\" (механизма управления информационным потоком).
- Использует нелинейность SiLU (Swish) вместо ReLU или GELU, повышая capacity блока.
- Является дефолтом во всех современных LLM, начиная с PaLM, LLaMA и Mistral.
Состоит из трех линейных слоев и активации SiLU:
1. Gate слой + SiLU активация
2. Up слой (линейное преобразование)
3. Element-wise multiplication gate и up
4. Down слой (линейная проекция)
Формула и математика:
---------------------
Пусть x — вход, then:
Научная суть:
- Сохраняет преимущества GLU (раздельные гейтом и телом) + мощность Swish/SiLU активации.
- Дает надежную гладкую активацию, хорошо работает на больших масштабах.
- Статья: "GLU Variants Improve Transformer" (Shazeer, 2020).
SwiGLU(x) = (SiLU(W_g x + b_g)) ⊙ (W_u x + b_u) W_d + b_d
Формула:
SwiGLU(x) = SiLU(W_g·x) * (W_u·x)
где SiLU(x) = x*sigma(x)
Типовая реализация (как здесь, по LLAMA/Mistral):
gate = SiLU(Linear_gate(x)) # фитчерный \"gate\"
up = Linear_up(x) # пропускная ветка
mult = gate * up # поэлементное умножение (контроль информации)
out = Linear_down(mult) # финальная проекция
out = Dropout(out) # регуляризация
Почему это работает:
-------------------
- Gating позволяет информации проходить \"частично\", динамически подавляя/усиливая сигналы в hidden-space.
- SiLU обеспечивает smooth градиенты (лучше для обучения LLM).
- В экспериментах (PaLM, LLAMA) SwiGLU consistently outperforms ReLU, GELU, обычные GLU.
Параметры конструктора:
-----------------------
emb_size: int
Размерность входного (и выходного) признакового пространства.
dropout: float
Dropout после final linear (обычно около 0.1).
Пример использования:
---------------------
>>> block = SwiGLU(emb_size=512, dropout=0.1)
>>> x = torch.randn(8, 16, 512)
>>> y = block(x)
>>> print(y.shape) # torch.Size([8, 16, 512])
References:
-----------
- Shazeer, \"GLU Variants Improve Transformer\", 2020: https://arxiv.org/abs/2002.05202
- PaLM: https://arxiv.org/abs/2204.02311 (Section 4.1)
- LLaMA: https://arxiv.org/abs/2302.13971
- Mistral: https://arxiv.org/abs/2310.06825
- HuggingFace discussion: https://huggingface.co/docs/transformers/main/en/model_doc/llama
Args:
emb_size (int): размер входов/выходов
dropout (float): после выходной проекции
Пример:
>>> ff = SwiGLU(emb_size=512, dropout=0.1)
>>> y = ff(torch.randn(2,10,512))
"""
def __init__(self, emb_size: int, dropout: float = 0.1):
@@ -68,34 +92,39 @@ class SwiGLU(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Прямой проход через SwiGLU слой.
Прямой проход через блок SwiGLU.
Args:
x: Входной тензор формы [batch_size, seq_len, emb_size]
-----
x : torch.Tensor
Входной тензор формы [batch_size, seq_len, emb_size]
Returns:
Выходной тензор формы [batch_size, seq_len, emb_size]
--------
torch.Tensor той же формы
Алгоритм:
---------
1. gate = SiLU(linear_gate(x))
2. up = linear_up(x)
3. output = linear_down(gate up)
4. apply dropout
3. mult = gate * up # поэлементно
4. out = linear_down(mult)
5. out = dropout(out)
"""
# Gate ветвь: линейное преобразование + активация
gate_out = self._gate(x) # [batch, seq, 4*emb]
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
gate_out = self._gate(x) # [batch, seq, 4*emb]
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
# Up ветвь: линейное преобразование
up_out = self._up(x) # [batch, seq, 4*emb]
up_out = self._up(x) # [batch, seq, 4*emb]
# Element-wise multiplication (gating mechanism)
out = up_out * activation_out # поэлементное умножение!
out = up_out * activation_out # поэлементное умножение!
# Final projection and dropout
out = self._down(out) # [batch, seq, emb]
out = self._down(out) # [batch, seq, emb]
return self._dropout(out)
def extra_repr(self) -> str:
"""Строковое представление для отладки."""
return f'emb_size={self._gate.in_features}, dropout={self._dropout.p}'
return f"emb_size={self._gate.in_features}, dropout={self._dropout.p}"

View File

@@ -2,68 +2,96 @@ import torch
from torch import nn
from torch import Tensor
class TokenEmbeddings(nn.Module):
"""
Токеновые эмбеддинги — обучаемые векторные представления для каждого токена словаря.
TokenEmbeddings — обучаемый слой эмбеддингов для токенов (слов, сабслов, байтов и т.д.) в трансформерах.
Преобразует целочисленные индексы токенов в обучаемые векторные представления фиксированного размера.
Обычно используется как первый слой в нейронных сетях для задач NLP.
Назначение:
-----------
- Преобразует каждый целочисленный индекс-токен из словаря (vocab) в обучаемый dense-вектор фиксированной длины.
- Это "входной слой" для любой нейросетевой языковой модели: позволяет работать с текстом как с матрицей чисел, а не с индексами/категориальными значениями.
- Обеспечивает возможность end-to-end обучения embedding-матрицы совместно с целью модели.
Научная суть:
- Первый шаг для любого NLP-модуля: вместо индекса токена подаём его dense-вектор.
- Эти вектора изучаются в процессе обучения и отражают скрытые взаимосвязи между токенами.
- Позволяют обрабатывать тексты как матрицу чисел, а не как символы или индексы.
- Аналог словарных эмбеддингов в word2vec, но обучаются энд-ту-энд с моделью.
Мотивация и особенности:
------------------------
- Каждый токен (индекс) получает свой learnable embedding (float-вектор).
- Размерность слоя: [vocab_size, emb_size] (матрица эмбеддингов).
- Веса эмбеддингов инициализируются случайно и обучаются вместе с остальной моделью.
- Аналог таблицы эмбеддингов в word2vec/fastText, но управляется end-to-end.
- Могут использоваться с любым токенизатором (BPE, SentencePiece, WordPiece и др.).
Формула:
--------
emb(x) = W[x], где W — матрица размера [vocab_size, emb_dim], x — индексы shape [batch, seq_len]
На выходе: тензор [batch, seq_len, emb_dim]
Args:
vocab_size (int): размер словаря (количество уникальных токенов)
emb_size (int): размерность эмбеддинга (длина вектора)
Примечание:
- Индексы должны быть в диапазоне [0, vocab_size-1]
- Эмбеддинги инициализируются случайно и обучаются в процессе тренировки модели
-----
vocab_size: int размер словаря/алфавита (количество уникальных токенов)
emb_size: int — размерность (длина) эмбеддинговых векторов (обычно 256/512/1024...)
Пример:
>>> emb = TokenEmbeddings(vocab_size=10000, emb_size=256)
>>> tokens = torch.tensor([[1, 2, 3]])
>>> vecs = emb(tokens)
>>> vecs.shape # torch.Size([1, 3, 256])
-------
>>> embedding = TokenEmbeddings(vocab_size=5000, emb_size=256)
>>> tokens = torch.tensor([[12, 47, 301], [6, 88, 413]])
>>> vecs = embedding(tokens)
>>> print(vecs.shape) # torch.Size([2, 3, 256])
References:
-----------
- Mikolov et al., "Efficient Estimation of Word Representations in Vector Space (word2vec)", 2013
- Vaswani et al., "Attention is All You Need", 2017: https://arxiv.org/abs/1706.03762
- BPE, SentencePiece overviews: https://huggingface.co/docs/transformers/tokenizer_summary
"""
def __init__(self, vocab_size: int, emb_size: int):
"""
Инициализация слоя эмбеддингов.
Args:
-----
vocab_size: int
Размер словаря (уникальных токенов/индексов).
emb_size: int
Длина эмбеддингового вектора для каждого токена.
Внутри:
-------
- Создаёт nn.Embedding с [vocab_size, emb_size] learnable весами.
"""
super().__init__()
self._embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=emb_size
num_embeddings=vocab_size, embedding_dim=emb_size
)
def forward(self, x: Tensor) -> Tensor:
"""
Получить эмбеддинги для входных токенов.
Args:
-----
x : torch.Tensor
Тензор shape [...], содержащий индексы токенов (каждое значение от 0 до vocab_size-1).
Returns:
--------
torch.Tensor — тензор обычной формы [..., emb_size] (на каждую позицию — свой embedding-вектор).
Пример:
-------
>>> embedding = TokenEmbeddings(vocab_size=100, emb_size=64)
>>> tokens = torch.tensor([[0, 99, 5]])
>>> vecs = embedding(tokens) # [1, 3, 64]
"""
return self._embedding(x)
@property
def num_embeddings(self) -> int:
"""Возвращает размер словаря"""
"""Возвращает размер словаря (количество уникальных токенов)."""
return self._embedding.num_embeddings
@property
def embedding_dim(self) -> int:
"""Возвращает размерность эмбеддингов"""
"""Возвращает размерность эмбеддингов (длина вектора каждого токена)."""
return self._embedding.embedding_dim
if __name__ == "__main__":
# Пример использования
embedding = TokenEmbeddings(vocab_size=100, emb_size=128)
# Создаем тензор с индексами в пределах vocab_size (0-99)
tensor = torch.tensor([
[11, 45, 76, 34],
[34, 67, 45, 54]
])
# Проверяем индексы
if (tensor >= 100).any():
raise ValueError("Some indices are out of vocabulary range (vocab_size=100)")
output = embedding(tensor)
print("Embeddings shape:", output.shape)
print(f"{output.shape} | {output.mean().item():.11f}") # Формат как в ТЗ

View File

View File

@@ -0,0 +1,120 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any
class StreamingTextDataset(Dataset):
"""
StreamingTextDataset — потоковый датасет для LLM/NLP на базе списка строк.
Назначение:
-----------
- Позволяет эффективно обрабатывать большие текстовые выборки, итерируя по заранее подготовленному списку строк.
- При итерации строки токенизируются на лету, превращаются в примеры фиксированной длины block_size (padding/truncation внутри класса).
- Поддерживает стандартный DataLoader PyTorch.
Ключевые особенности:
---------------------
- Не требует загрузки всей коллекции токенов в RAM: поддерживает работу с любым размером датасета, если список строк заранее подготовлен.
- Каждый пример (sample) формируется при обращении; не хранит массив батчей, не использует файлы внутри.
- Поддерживает любой токенизатор с методом encode (например, BPE, SentencePiece, HF Tokenizer).
- batch_size и параллелизм (num_workers) контролируются через DataLoader.
Аргументы конструктора:
-----------------------
texts: List[str] — список строк (предварительно загруженных обучающих примеров).
tokenizer: BaseTokenizer/Any — объект с методом encode(str, **kwargs) -> List[int].
block_size: int — длина одного выходного примера в токенах (padding/truncation если нужно).
Пример использования:
---------------------
>>> texts = open("wiki_sample.txt", encoding="utf-8").read().splitlines()
>>> ds = StreamingTextDataset(texts, tokenizer=tokenizer, block_size=512)
>>> loader = torch.utils.data.DataLoader(ds, batch_size=8)
>>> for batch in loader:
... print(batch['input_ids'].shape) # torch.Size([8, 512])
Особенности:
------------
- Проектирован для бесконечного стриминга текстовых данных из больших коллекций.
- При batch_size > 1 каждый batch формируется DataLoader-ом из yield'ов этого датасета.
- Не работает с файлами напрямую, только со строками (списком).
- Подходит для обучения LLM, тестирования, дообучения, оценки на больших потоковых данных.
References:
-----------
- PyTorch IterableDataset: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
- HuggingFace streaming datasets: https://huggingface.co/docs/datasets/stream
- Практика масштабного обучения LLM: https://github.com/karpathy/nanoGPT/issues/182
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
"""
Инициализация StreamingTextDataset из списка строк.
Аргументы:
texts (List[str]): Список строк — текстовые обучающие примеры; весь датасет должен помещаться в этот список.
tokenizer (Any): Токенизатор с методом encode(text, **kwargs) -> List[int].
block_size (int, по умолчанию 128): Желаемая длина токенизированного примера (padding/truncation внутри класса).
Особенности:
- Поддерживает итеративную загрузку, эффективен для больших текстовых выборок.
- Каждый пример автоматически дополняется или усекается до block_size.
- Не читает данные из файла/буфера, а только из заранее подготовленного списка строк.
Пример:
>>> ds = StreamingTextDataset(texts=all_lines, tokenizer=tokenizer, block_size=256)
>>> for ex in ds:
... print(ex['input_ids'].shape) # torch.Size([256])
"""
self.texts = texts
self.tokenizer = tokenizer
self.block_size = block_size
# Получаем pad_token_id из токенизатора
self.pad_token_id = getattr(tokenizer, "pad_token_id", 0)
def __len__(self):
"""
Возвращает количество доступных примеров в датасете.
Returns:
int: Число примеров (равно длине исходного списка строк).
"""
return len(self.texts)
def __getitem__(self, idx):
"""
Получить обработанный пример по индексу из потокового датасета.
Аргументы:
idx (int): Индекс примера в исходном списке строк.
Возвращает:
dict: Словарь с тензорами для обучения LLM:
- 'input_ids': torch.Tensor формы [block_size] — индексы токенов (padding/truncation выполнены)
- 'labels': torch.Tensor формы [block_size] — целевые метки (обычно совпадают с input_ids)
Пример:
>>> item = dataset[10]
>>> assert isinstance(item, dict)
>>> assert item['input_ids'].shape == (block_size,)
>>> assert 'labels' in item
"""
text = self.texts[idx]
# Токенизация на лету
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
# Обрезаем или дополняем до нужной длины
if len(input_ids) > self.block_size:
input_ids = input_ids[: self.block_size]
else:
input_ids = input_ids + [self.pad_token_id] * (
self.block_size - len(input_ids)
)
input_ids = torch.tensor(input_ids, dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}

View File

@@ -0,0 +1,112 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any
class TextDataset(Dataset):
"""
TextDataset — простой датасет для подачи обучающих токенов LLM (batch-режим или по одному примеру).
Назначение:
-----------
- Хранит последовательности текста (каждую строку или пример) в виде списка строк.
- При обращении сам токенизирует строку в последовательность индексов с помощью заданного токенизатора.
- Каждый пример автоматически усекётся или будет дополнен до фиксированной длины block_size (padding — zeros).
Формат и аргументы конструктора:
-------------------------------
texts: List[str]
Список строк, каждая из которых рассматривается как отдельный обучающий пример.
tokenizer: любой объект с методом encode(str, **kwargs) → List[int]
Обеспечивает сопоставление строки списку токенов (например, BPE, HuggingFace, SentencePiece и др.).
block_size: int, по умолчанию 128
Желаемая длина выходной последовательности (padding/truncation внутри класса).
Особенности:
------------
- Класс не работает с файлами напрямую: данные передаются готовым списком строк.
- При недостаточной длине пример дополняется паддингом (нулём или другим токеном, зависит от реализации).
- Может возвращать dict с input_ids, labels и прочими ключами (см. реализацию в функции __getitem__).
Пример использования:
---------------------
>>> with open("dataset.txt", encoding="utf-8") as f:
... texts = f.read().splitlines()
>>> dataset = TextDataset(texts, tokenizer, block_size=256)
>>> from torch.utils.data import DataLoader
>>> loader = DataLoader(dataset, batch_size=4)
>>> for item in loader:
... # item['input_ids'] для обучения LLM
References:
-----------
- Torch Dataset: https://pytorch.org/docs/stable/data.html
- Примеры LLM датасетов в open-source: https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/tokenize.py
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
"""
Инициализация датасета из списка строк.
Аргументы:
texts (List[str]): Список строк — каждый элемент отдельный обучающий пример.
tokenizer (Any): Токенизатор с методом encode(str, **kwargs) -> List[int].
block_size (int, по умолчанию 128): Желаемая длина результата —
длинные последовательности будут усечены, короткие — дополнены паддингом (pad_token_id или 0).
Особенности:
- Строки не фильтруются и не изменяются внутри датасета.
- Для PAD используется pad_token_id из токенизатора (если есть) либо 0.
- Dict, возвращаемый __getitem__, содержит 'input_ids' и 'labels'.
Пример:
>>> dataset = TextDataset(["hello world", "test string"], tokenizer, block_size=16)
"""
self.examples = []
self.tokenizer = tokenizer
self.block_size = block_size
for text in texts:
# Кодируем текст в токены
input_ids = tokenizer.encode(text, add_special_tokens=False)
# Обрезаем или дополняем до нужной длины
if len(input_ids) > block_size:
input_ids = input_ids[:block_size]
else:
# Дополняем pad_token_id
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
self.examples.append(input_ids)
def __len__(self):
"""
Возвращает количество примеров в датасете (длина списка текстов).
Returns:
int: Число примеров в датасете.
"""
return len(self.examples)
def __getitem__(self, idx):
"""
Получить пример из датасета по индексу.
Аргументы:
idx (int): Индекс примера.
Возвращает:
dict: Словарь с тензорами токенов для модели:
- 'input_ids': torch.Tensor shape [block_size], индексы токенов для входа.
- 'labels': torch.Tensor shape [block_size], метки для LM задачи (обычно совпадают с input_ids).
Пример:
>>> item = dataset[7]
>>> assert isinstance(item, dict)
>>> assert item['input_ids'].shape == (block_size,)
>>> assert 'labels' in item
"""
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}

View File

@@ -0,0 +1,124 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any
from llm.datasets.text_dataset import TextDataset
class TextWithSpecialTokensDataset(TextDataset):
"""
TextWithSpecialTokensDataset — датасет для языковых моделей с поддержкой специальных токенов (BOS, EOS, PAD).
Назначение:
-----------
- Работает с уже готовым списком строк (не с файлом!).
- Токенизирует строки с помощью заданного токенизатора, вручную вставляет специальные токены (BOS/ EOS/ PAD).
- Обрезает или дополняет каждую последовательность до длины block_size.
Аргументы конструктора:
-----------------------
texts (List[str]): Список обучающих строк (примеров).
tokenizer (Any): Любой токенизатор с методом encode(text, **kwargs).
block_size (int, default=128): Желаемая длина примера (padding/truncation).
add_bos (bool, default=False): Если True, добавляет BOS-токен в начало каждой последовательности.
add_eos (bool, default=False): Если True, добавляет EOS-токен в конец.
Особенности:
------------
- Если pad_token_id не задан — по умолчанию паддит нулями.
- Все returned примеры — dict с 'input_ids' и 'labels' (shape == block_size).
- Обрезание/дополнение учётное: BOS/EOS не "выдавливаются" обрезкой.
- Пример вызова:
>>> texts = ["пример текста", "ещё текст"]
>>> ds = TextWithSpecialTokensDataset(texts, tokenizer, block_size=16, add_bos=True, add_eos=True)
>>> out = ds[0]
>>> assert out['input_ids'].shape == (16,)
References:
-----------
- OpenAI GPT-2 data loader: https://github.com/openai/gpt-2/blob/master/src/encode.py
- HuggingFace data docs: https://huggingface.co/docs/transformers/pad_truncation
"""
def __init__(
self,
texts: List[str],
tokenizer: Any,
block_size: int = 128,
add_bos: bool = False,
add_eos: bool = False,
):
"""
Инициализация датасета с поддержкой специальных токенов.
Args:
texts (List[str]): Список строк (все ваши обучающие примеры).
tokenizer (Any): Токенизатор с методом encode(text, **kwargs).
block_size (int): Длина выходного примера.
add_bos (bool): Добавлять ли BOS токен в начало.
add_eos (bool): Добавлять ли EOS токен в конец.
"""
self.examples = []
self.tokenizer = tokenizer
self.block_size = block_size
self.add_bos = add_bos
self.add_eos = add_eos
for text in texts:
# Кодируем с специальными токенами
input_ids = tokenizer.encode(
text, add_special_tokens=True, add_bos_token=add_bos, add_eos_token=add_eos
)
# Учитываем специальные токены при обрезке/дополнении
effective_block_size = block_size
if add_bos:
effective_block_size -= 1
if add_eos:
effective_block_size -= 1
if len(input_ids) > effective_block_size:
input_ids = input_ids[:effective_block_size]
# Добавляем специальные токены если нужно
if (
add_bos
and hasattr(tokenizer, "bos_token_id")
and tokenizer.bos_token_id is not None
):
input_ids = [tokenizer.bos_token_id] + input_ids
if (
add_eos
and hasattr(tokenizer, "eos_token_id")
and tokenizer.eos_token_id is not None
):
input_ids = input_ids + [tokenizer.eos_token_id]
# Дополняем до полной длины
pad_token_id = getattr(tokenizer, "pad_token_id", 0)
if len(input_ids) < block_size:
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
self.examples.append(input_ids)
def __len__(self):
"""
Возвращает количество примеров в датасете.
Returns:
int: Размер (len(self.examples)).
"""
return len(self.examples)
def __getitem__(self, idx):
"""
Получить пример с учётом специальных токенов и паддинга.
Args:
idx (int): Индекс в dataset.
Returns:
dict: {'input_ids': torch.Tensor [block_size], 'labels': torch.Tensor [block_size]}
"""
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}

View File

@@ -0,0 +1,3 @@
from .gemma import Gemma
__all__ = ["Gemma"]

View File

@@ -0,0 +1,346 @@
import torch
import math
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.rope import RoPE
from llm.core.rms_norm import RMSNorm
from llm.core.gemma_decoder import GemmaDecoder
class Gemma(BaseModel):
"""
Gemma — языковая трансформер-модель от Google, с архитектурой, оптимизированной для open-source и research-комьюнити.
Назначение:
-----------
Модель Gemma реализует стек современных декодерных блоков (GemmaDecoder), поддерживает rotary-позиционирование, multi-query self-attention,
эффективный режим генерации (KV-cache), dropout, compact residual connections, базируется на best-practice LLM-инженерии последних лет.
Поддерживает batched-тренировку и inference, генерацию с различными стратегиями выборки (greedy, top-k, top-p), автосохранение.
Архитектурные особенности:
--------------------------
- Stack из N слоёв GemmaDecoder (attention с Multi-Query либо Grouped heads, FFN с GeGLU/SwiGLU)
- RMSNorm или LayerNorm для стабилизации
- Dropout для регуляризации
- Rotary Position Embedding (RoPE) для позиционных кодов
- Выходная проекция (linear → logits) к словарю токенов
- Полная поддержка cache для ускорения autoregressive генерации
Конфиг/Параметры конструктора:
------------------------------
config : dict
Словарь c параметрами модели:
- vocab_size : int — размер словаря
- embed_dim : int — размер скрытого (hidden) пространства
- max_position_embeddings : int — максимальная длина последовательности
- num_layers : int — количество декодерных блоков
- num_q_heads : int — количество attention голов (Queries)
- num_kv_heads : int — количество ключевых/значенческих attention голов
- dropout : float — Dropout率
- ... (доп. гиперпараметры, требуемые GemmaDecoder'ами)
Основные методы:
----------------
- forward(x, use_cache=True, cache=None): выдает батч логитов по токенам, возвращает при необходимости обновленный cache.
- generate(...): автотекстогенерация с greedy, temperature, top-k/p sampling, поддержкой кэша (ускорение inference).
- save(path)/load(path, device): сохранение и загрузка предобученных весов, параметров и состояния.
Пример:
-------
>>> config = {...} # словарь с параметрами
>>> model = Gemma(config)
>>> x = torch.randint(0, config["vocab_size"], (4, 64))
>>> logits, cache = model(x, use_cache=True)
>>> print(logits.shape) # [4, 64, vocab_size]
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.8)
Литература и ссылки:
--------------------
- Gemma: https://ai.google.dev/gemma (официальная страница)
- Разработка и архитектура: https://arxiv.org/abs/2403.07794
- Rotary Embedding: https://arxiv.org/abs/2104.09864
- Multi-Query Attention: https://arxiv.org/abs/1911.02150
- Llama: https://arxiv.org/abs/2302.13971
"""
def __init__(self, config):
"""
Конструктор класса Gemma.
Позволяет создать объект языковой модели с архитектурой Gemma и
произвольной конфигурацией (гибкая поддержка разных масштабов, ширин, глубин).
Аргументы:
----------
config : dict
Словарь со всеми необходимыми гиперпараметрами и архитектурными детальями модели Gemma.
Ожидаемые ключи (группы параметров):
- vocab_size : int — размер словаря токенов (размерность входа/выхода)
- embed_dim : int — скрытый размер эмбеддинга (hidden dim)
- max_position_embeddings : int — максимальная длина последовательности
- num_layers : int — количество декодерных блоков (глубина стека)
- num_q_heads : int — число attention голов (Query heads)
- num_kv_heads : int — число голов для Key/Value (MultiQuery Attention)
- dropout : float — Dropout для регуляризации
- остальные специфичные для GemmaDecoder'ов параметры
Внутри:
-------
- Инициализируются модули эмбеддинга токенов, позиционного кодирования (RoPE) и Dropout,
стек декодеров (GemmaDecoder(...)), слой финальной нормализации и выходная проекция (linear).
- Все архитектурные параметры напрямую берутся из config.
Пример:
-------
>>> config = {
... "vocab_size": 32000,
... "embed_dim": 512,
... "max_position_embeddings": 2048,
... "num_layers": 24,
... "num_q_heads": 8,
... "num_kv_heads": 4,
... "dropout": 0.1,
... }
>>> model = Gemma(config)
Примечание:
-----------
- Внимание: значения config должны быть согласованы друг с другом! Например, embed_dim должен быть кратным num_q_heads и т.д.
- Поддерживается дальнейшая кастомизация стека декодеров через ключи в config.
"""
super().__init__(config)
self._max_seq_len = config["max_position_embeddings"]
# Инициализация слоев
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
)
self._position_embeddings = RoPE(
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"]
)
#self._position_embeddings = PositionalEmbeddings(
# max_seq_len=max_seq_len,
# emb_size=emb_size
#)
self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([GemmaDecoder(
num_q_heads=config["num_q_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"],
rope=self._position_embeddings,
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._norm = RMSNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
"""
Прямой проход (forward) через полную модель Gemma.
Трансформирует входную последовательность токенов через стек из декодерных блоков GemmaDecoder.
Возвращает логиты по всем токенам и (при необходимости) кэш attention для быстрой autoregressive-генерации.
Аргументы:
----------
x : torch.Tensor
Входной тензор shape [batch_size, seq_len], содержащий токен-IDs.
use_cache : bool, по умолчанию True
Если True — сохраняет и возвращает KV-кэш attention (ускоряет автогенерацию).
Если False — кэш не используется.
cache : list, optional
(Необязательно) Список/None: с кэшами KV-матриц для каждого слоя (для режима генерации статей/диalogов).
Возвращает:
-----------
tuple:
- logits : torch.Tensor shape [batch_size, seq_len, vocab_size]
Логиты по словарю для каждого токена (input + сколь угодно новых).
- new_cache : list или None
Обновлённый cache (если use_cache=True).
Пример:
-------
>>> logits, new_cache = model(x, use_cache=True, cache=None)
>>> logits.shape # [batch_size, seq_len, vocab_size]
Примечания:
-----------
- Используется при обучении и инференсе.
- Если нужно только инференс last-token — используйте logits[:, -1, :].
- При превышении x.shape[1] > max_seq_len выдаёт ValueError.
"""
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
out = self._norm(out)
logits = self._linear(out)
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с использованием greedy, temperature, top-k и top-p sampling.
Реализует generation-loop с обновлением attention-кэша для ускорения инференса.
Аргументы:
----------
x : torch.Tensor
Входной тензор с последовательностью токенов (shape [batch_size, seq_len]), который необходимо продолжить.
max_new_tokens : int
Сколько новых токенов сгенерировать (максимум).
do_sample : bool
Если True — сэмплирует следующий токен согласно распределению вероятностей (stochastic), иначе выбирает токен с максимальной вероятностью (greedy).
temperature : float, default=1.0
Параметр для шкалирования распределения вероятностей логитов. Больше 1.0 — больше случайности, меньше 1.0 — более детерминированный (жёсткий) выбор.
top_k : int, optional
Если задано — для сэмплирования учитываются только top_k наиболее вероятных токенов.
top_p : float, optional
Если задано — работают nucleus sampling: учитываются токены, суммарная вероятность которых не превышает top_p.
use_cache : bool, default=True
Если True — для ускорения использует и обновляет attention-кэши (KV-cache).
Возвращает:
-----------
torch.Tensor
Тензор shape [batch_size, seq_len + max_new_tokens] с исходными и сгенерированными токенами (token IDs).
Пример:
-------
>>> out = model.generate(
... x, max_new_tokens=20, do_sample=True, temperature=0.8, top_k=50
... )
>>> print(out.shape) # [batch_size, seq_len+20]
Примечания:
-----------
- Нельзя указывать одновременно top_k и top_p (будет выброшено исключение).
- temperature <= 0 некорректно (будет выброшено исключение).
- Поддержка cache (use_cache=True) значительно ускоряет генерацию длинных последовательностей и позволяет использовать beam search/decoding.
- Для воспроизводимых результатов установите torch.manual_seed перед генерацией.
- Метод возвращает только token_ids, если нужны logits — используйте .forward напрямую.
Литература:
-----------
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
- Gemma: https://arxiv.org/abs/2403.07794
"""
cache = None
for _ in range(max_new_tokens):
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
# Масштабируем логиты температурой
if temperature > 0:
logits_scaled = last_logits / temperature
else:
logits_scaled = last_logits
if do_sample == True and top_k != None:
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
# # Заменим все НЕ top-k логиты на -inf
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
@property
def max_seq_len(self) -> int:
return self._max_seq_len

View File

@@ -26,56 +26,106 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict
from llm.core.base_model import BaseModel
from llm.core.decoder import Decoder
from llm.core.gpt_decoder import GptDecoder
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.positional_embeddings import PositionalEmbeddings
class GPT(BaseModel):
"""
Original GPT (Generative Pre-trained Transformer) модель.
GPT (Generative Pretrained Transformer) — автогерессивная языковая модель по мотивам оригинального GPT/GPT-2 architecture.
Первая версия трансформерной архитектуры от OpenAI, предназначенная
для генеративного предобучения на текстовых данных.
Назначение:
-----------
- Позволяет предсказывать и генерировать последовательности текста, обучаясь на задаче language modeling (предсказывать следующий токен).
- Класс реализует архитектуру classic Transformer Decoder Stack с masked multi-head attention и token/positional embeddings.
- Используется как базовая модель для генерации, zero-/few-shot, задач обучения с подкреплением и пр.
Args:
config: Словарь конфигурации с параметрами:
- vocab_size: Размер словаря токенов
- embed_dim: Размерность векторных представлений
- num_heads: Количество голов внимания
- num_layers: Количество декодерных слоев
- max_position_embeddings: Максимальная длина последовательности
- dropout: Вероятность dropout
Архитектурные особенности:
--------------------------
- Embedding-слои для токенов (token_embeddings) и позиций (position_embeddings).
- Stack из N декодер-блоков (MultiHeadAttention + FeedForward + residual + LayerNorm).
- Masked self-attention — каждый токен видит только свои и предыдущие, обеспечивая автогерессию.
- LayerNorm до проекции на словарь (pre-LN).
- Поддержка efficient KV кэша — ускоряет autoregressive inference/generation.
Attributes:
_token_embeddings: Слой векторных представлений токенов
_position_embeddings: Слой позиционных эмбеддингов
_decoders: Список декодерных слоев
_norm: Финальный слой нормализации
_linear: Выходной линейный слой
Основные параметры:
-------------------
config: dict в формате {
vocab_size, # размер словаря токенов
embed_dim, # размерность эмбеддинга
num_heads, # количество attention heads
num_layers, # глубина модели (число блоков)
max_position_embeddings,
dropout
}
Формула и поток данных:
-----------------------
x -> token_embeddings -> + position_embeddings -> dropout ->
-> stack([DecoderBlock]) ->
-> LayerNorm ->
-> Linear(out_dim=vocab_size) -> output_logits
Пример использования:
---------------------
>>> gpt = GPT({...})
>>> tokens = torch.tensor([[12, 123, 44]])
>>> logits = gpt(tokens)
>>> generated = gpt.generate(tokens, max_new_tokens=10)
References:
-----------
- Radford et al., "Improving Language Understanding by Generative Pre-Training" (GPT-1, 2018)
https://cdn.openai.com/research-covers/languageunsupervised/language_understanding_paper.pdf
- Original BPE Tokenizer code: https://github.com/openai/gpt-2/blob/master/src/encoder.py
- Формула masked self-attention: Vaswani et al., "Attention is All You Need", 2017
https://arxiv.org/abs/1706.03762
"""
def __init__(self, config):
"""
Инициализация модели GPT.
Args:
-----
config: dict
Параметры архитектуры:
vocab_size: int — размер словаря токенов
embed_dim: int — размерность эмбеддинга
num_heads: int — количество attention-heads
num_layers: int — число Transformer блоков
max_position_embeddings: int — макс. длина последовательности
dropout: float — dropout
Внутри:
-------
- Создаёт слой эмбеддингов, позиционку, стек декодеров, нормализацию, линейную проекцию.
"""
super().__init__(config)
# Инициализация слоев
self._max_seq_len = config["max_position_embeddings"]
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
vocab_size=config["vocab_size"], emb_size=config["embed_dim"]
)
self._position_embeddings = PositionalEmbeddings(
max_seq_len=config["max_position_embeddings"],
emb_size=config["embed_dim"]
max_seq_len=config["max_position_embeddings"], emb_size=config["embed_dim"]
)
self._dropout = nn.Dropout(config["dropout"])
# head_size = emb_size // num_heads
self._decoders = nn.ModuleList([Decoder(
num_heads=config["num_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_heads"],
max_seq_len=config["max_position_embeddings"],
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._decoders = nn.ModuleList(
[
GptDecoder(
num_heads=config["num_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_heads"],
max_seq_len=config["max_position_embeddings"],
dropout=config["dropout"],
)
for _ in range(config["num_layers"])
]
)
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
@property
@@ -83,155 +133,165 @@ class GPT(BaseModel):
"""Возвращает максимальную длину последовательности."""
return self._max_seq_len
def forward(self, x: torch.Tensor, attention_mask=None) -> torch.Tensor:
"""Прямой проход через GPT
def forward(
self, x: torch.Tensor, attention_mask=None, use_cache: bool = True, cache: list = None
) -> tuple:
"""
Прямой проход для получения логитов по последовательности токенов.
Args:
x: Входной тензор [batch_size, seq_len]
-----
x : torch.Tensor [batch, seq_len]
Индексы входных токенов.
use_cache : bool, optional
Использовать ли кэш attention (ускоряет инференс, важно для генерации)
cache : list, optional
Список старых KV (key/value)-кэшей
Returns:
Тензор логитов [batch_size, seq_len, vocab_size]
--------
logits: [batch, seq_len, vocab_size] (логиты для softmax по словарю)
new_cache: кэш KV после прохода
"""
# Проверка длины последовательности
if x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self._max_seq_len}")
raise ValueError(
f"Длина последовательности {x.size(1)} превышает максимальную {self._max_seq_len}"
)
# Вычисление start_pos из кэша (если кэш передан)
if cache is not None:
seq_len = 1
# Безопасно извлекаем key_cache для вычисления start_pos
if (
isinstance(cache, (list, tuple))
and len(cache) > 0
and cache[0] is not None
and isinstance(cache[0], (list, tuple))
and len(cache[0]) > 0
and cache[0][0] is not None
and isinstance(cache[0][0], (tuple, list))
and len(cache[0][0]) > 0
):
key_cache, _ = cache[0][0]
start_pos = key_cache.size(1)
else:
start_pos = 0
else:
# Без кэша работаем как раньше
start_pos = 0
seq_len = x.size(1)
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
pos_out = self._position_embeddings(x.size(1)) # [seq_len, emb_size]
pos_out = self._position_embeddings(
seq_len, start_pos=start_pos
) # [seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]
out = self._dropout(
tok_out + pos_out.unsqueeze(0)
) # [batch, seq_len, emb_size]
# Стек декодеров
for decoder in self._decoders:
out = decoder(out)
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
return self._linear(out) # [batch, seq_len, vocab_size]
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
logits = self._linear(out) # [batch, seq_len, vocab_size]
# def forward(self, input_ids, attention_mask=None):
# B, T = input_ids.size()
# pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
#
# x = self.token_emb(input_ids) + self.pos_emb(pos)
#
# for block in self.blocks:
# x = block(x, attention_mask)
#
# x = self.ln_f(x)
# logits = self.head(x)
# return logits
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(self,
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
attention_mask: torch.Tensor = None, # Добавляем для совместимости с HF
**kwargs # Игнорируем остальные параметры
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""Авторегрессивная генерация текста.
"""
Авторегрессивная генерация текста с поддержкой жадного поиска (greedy), вероятностного сэмплирования с температурой,
top-k и nucleus (top-p) sampling.
Параметры:
x: Входной тензор с индексами токенов формы [batch_size, seq_len],
где batch_size - размер батча, seq_len - длина последовательности.
max_new_tokens: Максимальное количество новых токенов для генерации.
do_sample: Флаг выбора режима генерации:
- True: вероятностное сэмплирование
- False: жадный поиск (argmax)
temperature: Параметр температуры для сэмплирования:
- >1.0 - более случайные результаты
- 1.0 - нейтральное значение
- <1.0 - более предсказуемые результаты
Должна быть > 0 (по умолчанию: 1.0)
top_k: Если задан (и do_sample=True), используется top-k сэмплирование:
- Выбираются только top_k самых вероятных токенов
- Остальным токенам устанавливается вероятность 0
- None: отключено (по умолчанию)
top_p: Если задан (и do_sample=True), используется nucleus (top-p) сэмплирование:
- Выбираются токены с кумулятивной вероятностью ≤ top_p
- Гарантируется, что хотя бы один токен остаётся (даже если его вероятность > top_p)
- None: отключено (по умолчанию)
- Должен быть в диапазоне (0, 1]
Аргументы:
x (torch.Tensor): Входной тензор с индексами токенов, форма [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Если True — вероятностное сэмплирование; если False — жадная генерация (argmax).
temperature (float): Температура для управления случайностью (>0, влияет только если do_sample=True).
>1.0 — более случайно, <1.0 — более детерминированно.
top_k (int, опц.): При do_sample=True ограничивает выбор top_k самых вероятных токенов (top-k sampling).
top_p (float, опц.): При do_sample=True включает top-p (nucleus) sampling: кумулятивная вероятность ≤ top_p.
Должно быть в (0, 1].
attention_mask (torch.Tensor, опц.): Внешняя маска внимания (для совместимости с HuggingFace).
**kwargs: Игнорируются.
Возвращает:
torch.Tensor: Тензор с расширенной последовательностью токенов формы
[batch_size, seq_len + max_new_tokens]
torch.Tensor: Последовательность токенов [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если входная последовательность длиннее max_seq_len
ValueError: Если temperature <= 0
ValueError: Если одновременно заданы top_k и top_p
ValueError: Если top_k задан и ≤ 0
ValueError: Если top_p задан и не в диапазоне (0, 1]
ValueError: Если x длиннее max_seq_len модели.
ValueError: Если temperature 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p вне диапазона (0, 1].
Примеры:
>>> # Жадная генерация
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
>>>
>>> # Вероятностная генерация с top-k
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, top_k=50)
>>>
>>> # Nucleus sampling (top-p)
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, top_p=0.9)
>>>
>>> # Жадная (детерминированная) генерация
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=False)
>>> # Вероятностная генерация с температурой
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, temperature=0.8)
>>> # Top-k сэмплирование
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, top_k=50)
>>> # Top-p (nucleus) sampling
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, top_p=0.92)
>>> # Комбинация температуры и top-k
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True,
... temperature=0.7, top_k=50)
>>> output = model.generate(input_ids, max_new_tokens=12, do_sample=True, temperature=1.0, top_k=100)
Примечания:
1. Для детерминированных результатов в режиме сэмплирования
зафиксируйте random seed (torch.manual_seed).
2. Температура влияет только на режим сэмплирования (do_sample=True).
3. Одновременное использование top_k и top_p запрещено.
4. При do_sample=False параметры top_k, top_p и temperature игнорируются.
- Для детерминированных выборок зафиксируйте random seed через torch.manual_seed.
- Параметры temperature, top_k, top_p применимы только если do_sample=True.
- Одновременное использование top_k и top_p не допускается.
- Модель всегда возвращает тензор индексов токенов; для получения логитов используйте прямой вызов forward.
Args:
x (torch.Tensor): Входной тензор с индексами токенов формы [batch_size, seq_len],
где batch_size - размер батча, seq_len - длина последовательности.
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Флаг выбора режима генерации:
- True: вероятностное сэмплирование
- False: жадный поиск (argmax)
temperature (float): Параметр температуры для сэмплирования:
- >1.0 - более случайные результаты
- 1.0 - нейтральное значение
- <1.0 - более предсказуемые результаты
Должна быть > 0 (по умолчанию: 1.0)
Returns:
torch.Tensor: Тензор с расширенной последовательностью токенов формы
[batch_size, seq_len + max_new_tokens]
Raises:
ValueError: Если входная последовательность длиннее max_seq_len
ValueError: Если temperature <= 0
Examples:
>>> # Жадная генерация
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=False)
>>>
>>> # Вероятностная генерация с температурой
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, temperature=0.7)
>>>
>>> # Более случайная генерация
>>> output = model.generate(input_ids, max_new_tokens=10, do_sample=True, temperature=1.5)
Note:
Для детерминированных результатов в режиме сэмплирования
зафиксируйте random seed (torch.manual_seed).
Температура влияет только на режим сэмплирования (do_sample=True).
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
- Оригинальный GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
"""
cache = None
for _ in range(max_new_tokens):
# 1. Обрезаем вход, если последовательность слишком длинная
x_cond = x[:, -self._max_seq_len:]
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# 2. Передаем последовательность в метод forward класса GPT и полуаем логиты.
logits = self.forward(x_cond)
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
# 3. Берем логиты для последнего токена
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
@@ -250,9 +310,14 @@ class GPT(BaseModel):
vocab_size = logits_scaled.size(-1)
# создаём маску: True, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, 'bool') else 0) # False там, где top-k индексы
masked_logits[mask] = float('-inf')
mask = torch.ones_like(
logits_scaled,
dtype=torch.bool if hasattr(torch, "bool") else torch.uint8,
)
mask.scatter_(
1, topk_indices, False if hasattr(torch, "bool") else 0
) # False там, где top-k индексы
masked_logits[mask] = float("-inf")
logits_scaled = masked_logits
@@ -260,36 +325,42 @@ class GPT(BaseModel):
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
sorted_probs, sorted_indices = torch.sort(
probs, descending=True, dim=-1
)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p) # [B, vocab_size]
sorted_mask = cum_probs <= top_p # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из False
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, 'bool') else torch.uint8)
mask = torch.zeros_like(
probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8
)
# Устанавливаем True в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
logits_scaled[~mask] = float("-inf")
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
next_token = torch.argmax(
probs, dim=-1, keepdim=True
) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
# def generate(self, input_ids, max_length=50):
# for _ in range(max_length):
# logits = self.forward(input_ids)

View File

@@ -27,81 +27,144 @@ from llm.core.positional_embeddings import PositionalEmbeddings
from llm.core.cached_decoder import CachedDecoder
from llm.core.feed_forward import FeedForward
class GPT2(BaseModel):
"""
GPT2 — автогерессивная языковая модель, архитектура Transformer, предложенная OpenAI.
GPT-2 — масштабируемый автогерессивный языковой трансформер второго поколения от OpenAI (2019).
Научная суть:
- Масштабируемый автогерессивный трансформер для предсказания токенов слева направо.
- Главное отличие от классической GPT: порядок layer normalization ПЕРЕД attention и FFN.
- Используется GELU, efficient KV-cache, несет наследие классической GPT, но делает архитектуру глубже/шире.
Назначение:
-----------
- Позволяет предсказывать и порождать последовательности текста по одному токену, будучи обученным на задаче language modeling.
- Модель реализует архитектуру decoder-only Transformer с Pre-LN (LayerNorm перед attention и FFN).
- Используется для генерации, обучения с подкреплением для RLHF, zero/few-shot inference, чат-ботов и др.
Args:
config (dict): параметры архитектуры (vocab_size, embed_dim, num_heads, num_layers, max_position_embeddings, dropout)
Архитектурные особенности:
--------------------------
- Token и positional embeddings (learnable, как в GPT-2 оригинале).
- Stack из N блоков Decoder (MultiHeadAttention с causal mask, Residual, Pre-LayerNorm, GELU FFN).
- KV attention-кэш (ускоряет autoregressive generation, критически важно для LLM).
- Использует GELU как функцию активации.
- Поддержка dropout на каждом этапе.
Основные параметры:
-------------------
config: dict — параметры модели:
vocab_size, # размер словаря токенов
embed_dim, # размерность эмбеддинга
num_heads, # количество attention голов
num_layers, # глубина модели (число блоков)
max_position_embeddings,
dropout
Процессинг:
-----------
x (индексы токенов) → token_embeddings + position_embeddings → dropout
→ stack Decoder blocks (masked attention, pre-LN)
→ LayerNorm
→ Linear(out_dim=vocab_size) → выходные логиты
Пример использования:
>>> model = GPT2({"vocab_size": 50257, ...})
>>> logits = model(input_ids)
>>> out = model.generate(input_ids, max_length=20)
---------------------
>>> gpt2 = GPT2({...})
>>> logits = gpt2(input_ids)
>>> output = gpt2.generate(input_ids, max_new_tokens=20, do_sample=True)
References:
-----------
- Radford et al., "Language Models are Unsupervised Multitask Learners" (GPT-2, 2019): https://cdn.openai.com/better-language-models/language-models.pdf
- HuggingFace GPT-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
- Репликация в NanoGPT: https://github.com/karpathy/nanoGPT
"""
def __init__(self, config):
"""
Инициализация GPT-2.
Args:
config (dict): Параметры архитектуры:
vocab_size: int — размер словаря
embed_dim: int — размерность эмбеддинга
num_heads: int — количество attention-голов
num_layers: int — количество декодер-блоков
max_position_embeddings: максимальная длина последовательности
dropout: float — dropout
Внутри:
-------
- Создаёт токеновые и позиционные эмбеддинги, стек декодеров, финальный LayerNorm и линейную проекцию в словарь.
"""
super().__init__(config)
# Инициализация слоев
self._max_seq_len = config["max_position_embeddings"]
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
vocab_size=config["vocab_size"], emb_size=config["embed_dim"]
)
self._position_embeddings = PositionalEmbeddings(
max_seq_len=config["max_position_embeddings"],
emb_size=config["embed_dim"]
max_seq_len=config["max_position_embeddings"], emb_size=config["embed_dim"]
)
self._dropout = nn.Dropout(config["dropout"])
# head_size = emb_size // num_heads
self._decoders = nn.ModuleList([CachedDecoder(
num_heads=config["num_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_heads"],
feed_forward_layer=FeedForward(
emb_size=config["embed_dim"],
dropout=config["dropout"],
activation="gelu"
),
max_seq_len=config["max_position_embeddings"],
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._decoders = nn.ModuleList(
[
CachedDecoder(
num_heads=config["num_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_heads"],
feed_forward_layer=FeedForward(
emb_size=config["embed_dim"],
dropout=config["dropout"],
activation="gelu",
),
max_seq_len=config["max_position_embeddings"],
dropout=config["dropout"],
)
for _ in range(config["num_layers"])
]
)
self._norm = nn.LayerNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
def forward(
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
) -> tuple:
"""
Прямой проход GPT2:
- Все слои работают как autoregressive transformer (masked self-attention).
- При use_cache=True возвращает также новый кэш KV attention (ускоряет генерацию).
Прямой проход для batch of sequences (получение логитов по токенам).
Args:
x (Tensor): Входные индексы токенов [batch, seq_len]
use_cache (bool): Кэшировать KV attention для ускорения autoregressive генерации
cache (list|None): Список KV-кэшей от предыдущих шагов (или None)
x (torch.Tensor): Входной тензор с токенами [batch, seq_len]
use_cache (bool): Использовать/возвращать кэш KV attention (ускоряет генерацию)
cache (list / None): Внешний кэш KV attention (передаётся при генерации)
Returns:
logits (Tensor): [batch, seq_len, vocab_size]
cache (list): новый кэш если use_cache=True, иначе None
logits: torch.Tensor [batch, seq_len, vocab_size]
new_cache: новый кэш KV attention (или None)
Пример:
>>> logits, cache = model.forward(x, use_cache=True)
>>> logits, cache = gpt2(x, use_cache=True)
"""
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
raise ValueError(
f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}"
)
# Вычисление start_pos из кэша (если кэш передан)
if cache is not None:
# При кэше обрабатываем только один токен (последний)
seq_len = 1
# Вычисляем start_pos из самого нижнего уровня кэша
if cache and cache[0] and cache[0][0]:
key_cache, _ = cache[0][0] # Первый декодер, первая голова
start_pos = key_cache.size(1) # cache_len
# Безопасно извлекаем key_cache для вычисления start_pos
if (
isinstance(cache, (list, tuple))
and len(cache) > 0
and cache[0] is not None
and isinstance(cache[0], (list, tuple))
and len(cache[0]) > 0
and cache[0][0] is not None
and isinstance(cache[0][0], (tuple, list))
and len(cache[0][0]) > 0
):
key_cache, _ = cache[0][0]
start_pos = key_cache.size(1)
else:
start_pos = 0
else:
@@ -111,10 +174,14 @@ class GPT2(BaseModel):
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
pos_out = self._position_embeddings(seq_len, start_pos=start_pos) # [seq_len, emb_size]
pos_out = self._position_embeddings(
seq_len, start_pos=start_pos
) # [seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out + pos_out.unsqueeze(0)) # [batch, seq_len, emb_size]
out = self._dropout(
tok_out + pos_out.unsqueeze(0)
) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
@@ -138,32 +205,69 @@ class GPT2(BaseModel):
else:
return (logits, None)
def generate(self,
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Генерация текста с использованием autoregressive трансформера (GPT2).
Поддерживаются greedy, sampling, top-k/top-p (nucleus sampling) режимы.
Args:
x (Tensor[int]): начальная последовательность [batch, seq_len]
max_new_tokens (int): сколько токенов сгенерировать
do_sample (bool): использовать стохастическое сэмплирование вместо жадного выбора
temperature (float): коэффициент сглаживания логитов (низкое — более консервативно)
top_k (int|None): ограничить выбор top-k наиболее вероятных токенов
top_p (float|None): ограничить суммарную вероятность (nucleus sampling)
use_cache (bool): ускорять autoregressive инференс
Returns:
output (Tensor[int]): сгенерированный тензор токенов [batch, seq_len + max_new_tokens]
Пример:
>>> prompt = tokenizer.encode('Привет', return_tensors="pt")
>>> output = model.generate(prompt, max_new_tokens=20, do_sample=True)
>>> print(tokenizer.decode(output[0]))
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k, top-p sampling и KV-кэша.
Аргументы:
x (torch.Tensor): Входной тензор с индексами токенов [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Режим генерации:
- True: вероятностное сэмплирование (random sampling)
- False: жадный (greedy) поиск (выбор argmax на каждом шаге)
temperature (float): Температура распределения (>0, по умолчанию 1.0).
- >1.0 — генерация более "творческая"/приподнятая вероятность "редких" токенов;
- <1.0 — более предсказуемый и суженный выбор.
top_k (int, опционально): Если задан, sampling только из top_k самых вероятных токенов (top-k sampling).
top_p (float, опционально): Если задан, sampling только из токенов, кумулятивная вероятность которых ≤ top_p (nucleus/top-p sampling, см. Holtzman et al., 2019).
use_cache (bool, по умолчанию True): Использовать кэш attention KV для ускорения авторегрессии.
Возвращает:
torch.Tensor: Тензор индексов токенов [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее максимальной длины (max_seq_len).
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры использования:
>>> # Жадная генерация
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=False)
>>> # Сэмплирование с температурой
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, temperature=0.8)
>>> # Top-k sampling
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_k=50)
>>> # Top-p (nucleus) sampling
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, top_p=0.92)
>>> # Комбинация температуры и top-k
>>> output = model.generate(input_ids, max_new_tokens=20, do_sample=True, temperature=0.7, top_k=40)
Примечания:
- Для детерминированных результатов используйте torch.manual_seed.
- temperature, top_k, top_p работают только при do_sample=True.
- Только один из top_k/top_p может быть задан одновременно.
- Метод всегда возвращает индексы токенов (ids); для получения логитов используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
- Оригинальная статья GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
"""
cache = None
@@ -198,26 +302,27 @@ class GPT2(BaseModel):
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
masked_logits[mask.byte()] = float('-inf')
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
sorted_probs, sorted_indices = torch.sort(
probs, descending=True, dim=-1
)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = 1
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.uint8)
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
@@ -226,13 +331,14 @@ class GPT2(BaseModel):
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
next_token = torch.argmax(
probs, dim=-1, keepdim=True
) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]

View File

@@ -10,77 +10,121 @@ from llm.core.rope import RoPE
from llm.core.cached_decoder import CachedDecoder
class Llama(BaseModel):
"""
LLaMA (Large Language Model Meta AI) — высокоэффективная масштабируемая языковая модель, разработанная Meta AI Research.
LLaMA — автогерессивная большая языковая модель (Large Language Model from Meta, 2023).
Ключевые идеи:
- Rotary Positional Encoding (RoPE) вместо стандартных позиционных эмбеддингов
- RMSNorm (Root Mean Square LayerNorm) вместо LayerNorm
- SwiGLU как нелинейность вместо ReLU/GELU (больше экспрессивности)
- Глубокая оптимизация inference (большая экономия памяти и FLOPs)
Подробнее: https://arxiv.org/abs/2302.13971
Назначение:
-----------
- Модель реализует архитектуру decoder-only Transformer с современными "индустриальными" трюками (RMSNorm, SwiGLU, RoPE, GQA).
- Предназначена для генерации текста, чат-ботов, zero-/few-shot вывода, fine-tune в стиле RLHF, transfer learning и исследований в LLM.
Архитектурные особенности:
--------------------------
- Токеновые эмбеддинги и позиционное кодирование с помощью Rotary Position Embedding (RoPE, https://arxiv.org/abs/2104.09864).
- Stack из num_layers современных декодеров с Grouped Query Attention (GQA: num_q_heads > num_kv_heads) для эффективной генерации.
- FeedForward блоки с SwiGLU (см. https://arxiv.org/abs/2002.05202).
- Нормализация RMSNorm перед каждым sub-layer (вот почему "Pre-RMSNorm").
- Кэширование attention (KV cache) для быстрой autoregressive генерации.
- Нет bias в Linear слоях, нет Dropout внутри attention.
Аргументы конструктора:
-----------------------
config: dict с требуемыми ключами:
vocab_size: int — размер словаря токенов
embed_dim: int — размерность эмбеддингов
num_q_heads: int — количество query-голов в attention (обычно больше num_kv_heads)
num_kv_heads: int — количество key/value-голов
num_layers: int — число слоёв-декодеров
max_position_embeddings: int — максимальная длина последовательности
window_size: int (optional) — размер sliding window для attention
dropout: float (обычно 0.0 или очень мал)
...
Пример использования:
---------------------
>>> llama = LLaMA({...})
>>> tokens = torch.tensor([[100, 56, 8]])
>>> logits = llama(tokens)
>>> out = llama.generate(tokens, max_new_tokens=10, do_sample=True, top_k=50)
References:
-----------
- "LLaMA: Open and Efficient Foundation Language Models" (Touvron et al., 2023): https://arxiv.org/abs/2302.13971
- "Grouped-Query Attention": https://arxiv.org/abs/2307.09288
- "RoFormer: Enhanced Transformer with Rotary Position Embedding": https://arxiv.org/abs/2104.09864
- Discussion of efficient LLMs: https://huggingface.co/blog/mistral
Args:
config (dict): параметры архитектуры (vocab_size, embed_dim, num_heads, num_layers, max_position_embeddings, dropout)
Пример:
>>> model = Llama({...})
>>> logits, cache = model(input_ids, use_cache=True)
>>> out = model.generate(input_ids, max_new_tokens=20)
"""
def __init__(self,config):
def __init__(self, config):
"""
Инициализация LLaMA.
Args:
config (dict): Параметры архитектуры, см. docstring класса.
Внутри:
-------
- Создаёт Embedding-слой, Rotary Position Embeddings (RoPE), стек слоёв с GQA, RMSNorm, SwiGLU.
- Финальный слой нормализации и проекции на vocabulary.
"""
super().__init__(config)
# Инициализация слоев
self._max_seq_len = config["max_position_embeddings"]
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
vocab_size=config["vocab_size"], emb_size=config["embed_dim"]
)
self._position_embeddings = RoPE(
head_size=config["embed_dim"] // config["num_heads"],
max_seq_len=config["max_position_embeddings"]
max_seq_len=config["max_position_embeddings"],
)
self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([CachedDecoder(
norm_layer=RMSNorm,
num_heads=config["num_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_heads"],
feed_forward_layer=SwiGLU(
emb_size=config["embed_dim"],
dropout=config["dropout"],
),
max_seq_len=config["max_position_embeddings"],
rope=self._position_embeddings,
dropout=config["dropout"],
) for _ in range(config["num_layers"])])
self._decoders = nn.ModuleList(
[
CachedDecoder(
norm_layer=RMSNorm,
num_heads=config["num_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_heads"],
feed_forward_layer=SwiGLU(
emb_size=config["embed_dim"],
dropout=config["dropout"],
),
max_seq_len=config["max_position_embeddings"],
rope=self._position_embeddings,
dropout=config["dropout"],
)
for _ in range(config["num_layers"])
]
)
self._norm = RMSNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
def forward(
self, x: torch.Tensor, use_cache: bool = True, cache: list = None
) -> tuple:
"""
Прямой проход через LLaMA (inference/train): авторегрессионное предсказание токенов.
Прямой проход: возвращает logits (и возможно обновлённый cache) по входным токенам.
Args:
x (Tensor[int]): входные токены [batch, seq_len]
use_cache (bool): использовать ли кэш (ускоряет генерацию)
cache (list|None): ключи и значения attention для autoregressive режима
x (torch.Tensor): [batch, seq_len] — индексы токенов, shape [batch, seq_len]
use_cache (bool): использовать механизм KV cache (ускоряет autoregressive generation)
cache (list or None): предыдущий кэш, если нужен
Returns:
logits (Tensor): [batch, seq_len, vocab_size]
new_cache (list|None): новый кэш attention (если use_cache)
Пример:
>>> logits, cache = model.forward(x, use_cache=True)
logits: torch.Tensor [batch, seq_len, vocab_size]
new_cache: новый кэш attention (или None)
"""
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
raise ValueError(
f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}"
)
# Вычисление start_pos из кэша (если кэш передан)
#if cache is not None:
# if cache is not None:
# # При кэше обрабатываем только один токен (последний)
# seq_len = 1
# # Вычисляем start_pos из самого нижнего уровня кэша
@@ -89,14 +133,14 @@ class Llama(BaseModel):
# start_pos = key_cache.size(1) # cache_len
# else:
# start_pos = 0
#else:
# else:
# # Без кэша работаем как раньше
# start_pos = 0
# seq_len = x.size(1)
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
@@ -123,35 +167,63 @@ class Llama(BaseModel):
else:
return (logits, None)
def generate(self,
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Генерация текста c помощью LLaMA (autoregressive Transformer).
Поддерживается:
- greedy и вероятностное сэмплирование (top-k, top-p, temperature)
- кэш attention для ускорения генерации длинных последовательностей
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).
Args:
x (Tensor[int]): начальная последовательность [batch, seq_len]
max_new_tokens (int): сколько новых токенов сгенерировать
do_sample (bool): использовать стохастику (True) или жадный выбор (False)
temperature (float): масштаб для softmax (важно для sampling)
top_k (int|None): ограничение на количество кандидатов (top-k sampling)
top_p (float|None): nucleus sampling
use_cache (bool): ускоряет autoregressive при длинной генерации
Returns:
output (Tensor[int]): [batch, seq_len + max_new_tokens]
Пример:
>>> prompt = tokenizer.encode('Meta AI', return_tensors="pt")
>>> generated = model.generate(prompt, max_new_tokens=30, do_sample=True)
>>> print(tokenizer.decode(generated[0]))
Аргументы:
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Использовать вероятностное сэмплирование (True) или жадный режим (False, argmax).
temperature (float): Температура (сглаживание распределения вероятностей, >0; по умолчанию 1.0).
>1.0 — менее предсказуемые, более разнообразные выборки.
<1.0 — более строгие, консервативные выборки.
top_k (int, опционально): Top-k сэмплирование (ограничение выбора k самыми вероятными токенами).
top_p (float, опционально): Nucleus (top-p) sampling (срез по кумулятивной вероятности ≤ top_p, см. Holtzman et al., 2019).
use_cache (bool, по умолчанию True): Использовать KV-кэш для ускорения генерации.
Возвращает:
torch.Tensor: Последовательность токенов shape [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее максимально допустимой длины (max_seq_len модели).
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры:
>>> # Строго жадная генерация
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
>>> # Вероятностная генерация с температурой
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.7)
>>> # Top-k sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
>>> # Top-p (nucleus)
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
>>> # Комбинация температуры и top-k
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
Примечания:
- temperature, top_k, top_p применяются только если do_sample=True.
- Одновременное использование top_k и top_p запрещено.
- Для воспроизводимых результатов зафиксируйте seed через torch.manual_seed.
- Возвращается только индексы токенов; для получения вероятностей используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p): https://arxiv.org/abs/1904.09751
- LLaMA: https://arxiv.org/abs/2302.13971
"""
cache = None
@@ -186,26 +258,27 @@ class Llama(BaseModel):
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.uint8)
mask.scatter_(1, topk_indices, 0) # 0 там, где top-k индексы
masked_logits[mask.byte()] = float('-inf')
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
sorted_probs, sorted_indices = torch.sort(
probs, descending=True, dim=-1
)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).byte() # [B, vocab_size]
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = 1
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.uint8)
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
@@ -214,20 +287,19 @@ class Llama(BaseModel):
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
next_token = torch.argmax(
probs, dim=-1, keepdim=True
) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
@property
def max_seq_len(self) -> int:
return self._max_seq_len

View File

@@ -0,0 +1,3 @@
from .mistral import Mistral
__all__ = ["Mistral"]

View File

@@ -0,0 +1,276 @@
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.rms_norm import RMSNorm
from llm.core.rope import RoPE
from llm.core.mistral_decoder import MistralDecoder
class Mistral(BaseModel):
"""
Mistral — автогерессивная языковая LLM-архитектура (2023, Mistral AI) для быстрого и качественного моделирования текста.
Назначение:
-----------
- Модель построена на базе decoder-only Transformer с важными оптимизациями: GQA (Grouped Query Attention), RoPE, SwiGLU, RMSNorm, sliding window attention.
- Поддерживает autoregressive generation (step-by-step текст), обучение и inference на длинных последовательностях.
- Используется в современных open-source LLM: Mistral-7B, Mixtral-8x7B и др.
Архитектурные особенности:
--------------------------
- Токеновые эмбеддинги (TokenEmbeddings) и позиционное кодирование через RoPE (rotary position embedding).
- Stack из num_layers декодеров с Grouped Query Attention (раздельное число query/key heads для оптимизации памяти).
- Sliding Window Attention Mask — позволяет ускорять обработку длинных текстов, ограничивая область внимания для каждого токена (как в оригинальном Mistral).
- SwiGLU FeedForward-блоки и RMSNorm.
- Dropout (регуляризация).
- Кэширование attention (KV cache) для быстрой генерации токенов по одному.
Аргументы конструктора:
-----------------------
config (dict): параметры модели (см. документацию Mistral):
vocab_size: int — размер словаря токенов
embed_dim: int — размерность эмбеддингов
num_q_heads: int — количество query-голов (обычно больше num_kv_heads)
num_kv_heads: int — количество key/value attention-голов
num_layers: int — число слоёв-декодеров
max_position_embeddings: int — максимальная длина последовательности
window_size: int — размер sliding window attention
dropout: float — dropout (обычно очень мал или 0)
...
Пример использования:
---------------------
>>> model = Mistral({...})
>>> tokens = torch.tensor([[100, 56, 8]])
>>> logits = model(tokens)
>>> generated = model.generate(tokens, max_new_tokens=16, do_sample=True, top_k=50)
References:
-----------
- "Mistral: Fast and Efficient Dense and Mixture of Experts Transformer Models" (2023): https://arxiv.org/abs/2310.06825
- LLaMA v2 & Grouped-Query Attention: https://arxiv.org/abs/2307.09288
- Оригинальное обсуждение архитектуры: https://huggingface.co/blog/mistral
"""
def __init__(self, config):
super().__init__(config)
self._max_seq_len = config["max_position_embeddings"]
# Инициализация слоев
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
)
self._position_embeddings = RoPE(
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"]
)
#self._position_embeddings = PositionalEmbeddings(
# max_seq_len=max_seq_len,
# emb_size=emb_size
#)
self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([MistralDecoder(
num_q_heads=config["num_q_heads"],
num_kv_heads=config["num_kv_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"],
window_size=config["window_size"],
rope=self._position_embeddings,
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._norm = RMSNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
"""
Прямой проход (forward) через всю модель Mistral: возвращает логиты для токенов и (опционально) кэш attention для ускорения autoregressive генерации.
Аргументы:
x (torch.Tensor): Входной тензор с токенами (shape [batch_size, seq_len]), где значения — индексы токенов.
use_cache (bool, по умолчанию True): Возвращать ли новый KV attention-кэш для последующей генерации.
cache (list or None): Предыдущий кэш attention (или None для полного прохода без накопления кэша).
Возвращает:
logits (torch.Tensor): Тензор логитов shape [batch_size, seq_len, vocab_size] — вероятностное распределение по словарю для каждого токена.
new_cache (list or None): Новый кэш KV attention-слоев (или None, если use_cache=False).
Исключения:
ValueError: Если длина последовательности превышает максимальную (max_seq_len), когда не используется кэш.
Пример:
>>> logits, cache = model.forward(input_ids, use_cache=True)
>>> probabilities = torch.softmax(logits, dim=-1)
"""
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
out = self._norm(out)
logits = self._linear(out)
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
Аргументы:
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
Возвращает:
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее max_seq_len модели.
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры:
>>> # Жадная генерация
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
>>> # Сэмплирование с температурой
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
>>> # Top-k sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
>>> # Top-p (nucleus) sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
>>> # Температура + top-k
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
Примечания:
- Одновременно использовать top_k и top_p нельзя.
- Параметры temperature, top_k, top_p работают только при do_sample=True.
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
- Mistral: https://arxiv.org/abs/2310.06825
"""
cache = None
for _ in range(max_new_tokens):
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
# Масштабируем логиты температурой
if temperature > 0:
logits_scaled = last_logits / temperature
else:
logits_scaled = last_logits
if do_sample == True and top_k != None:
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
# # Заменим все НЕ top-k логиты на -inf
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
@property
def max_seq_len(self) -> int:
return self._max_seq_len

View File

@@ -0,0 +1,3 @@
from .mixtral import Mixtral
__all__ = ["Mixtral"]

View File

@@ -0,0 +1,361 @@
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.rope import RoPE
from llm.core.rms_norm import RMSNorm
from llm.core.mixtral_decoder import MixtralDecoder
class Mixtral(BaseModel):
"""
Mixtral — языковая модель с архитектурой Mixture-of-Experts на основе современных трансформеров (см. Mixtral 8x7B).
Описание:
---------
Данный класс реализует полностью функциональную LLM с блоками MixtralDecoder, которые используют разреженные Feed-Forward сети MoE (Mixture-of-Experts)
и Grouped Query Attention (GQA). Позволяет масштабировать количество параметров без экспоненциального роста вычислительных затрат благодаря активации лишь части экспертов на каждый токен.
Mixtral поддерживает автотекстогенерацию с caching, position encoding через RoPE и всё необходимое для работы и тренировки современных LLM.
Архитектурные особенности:
--------------------------
- Stack из N слоёв MixtralDecoder (каждый — MoE-блок + attention + RMSNorm).
- Dropout для регуляризации на уровне эмбеддингов и слоёв.
- Позиционные эмбеддинги реализованы через RoPE (Rotary Positional Embeddings).
- Финальная RMSNorm плюс Linear-проекция к словарю токенов.
- Поддержка автогенерации с sampling (greedy, top-k, top-p), temperature и KV-cache.
Аргументы конструктора:
----------------------
config : dict
Словарь-конфиг с основными гиперпараметрами модели:
- vocab_size : int — размер словаря токенов
- embed_dim : int — размер скрытого пространства
- max_position_embeddings : int — макс. длина последовательности
- num_layers : int — количество декодерных блоков в стеке
- num_q_heads : int — число query-голов в attention
- num_kv_heads : int — число kv-голов в attention
- num_experts : int — число MoE-экспертов
- top_k_experts : int — сколько экспертов активировать на токен
- dropout : float — вероятность Dropout
- window_size : int — размер окна внимания
Основные методы:
----------------
- forward(x, use_cache=True, cache=None) — прямой проход, поддерживает batched вход, caching.
- generate(...) — авторегрессивная генерация с разными стратегиями sampling и ускорением через cache.
- save(path)/load(path, device) — сохранение и восстановление обученной модели.
Пример:
-------
>>> config = {...} # dict с параметрами
>>> model = Mixtral(config)
>>> x = torch.randint(0, config["vocab_size"], (2, 16))
>>> logits, cache = model(x, use_cache=True)
>>> print(logits.shape) # [2, 16, vocab_size]
>>> # Генерация
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.9)
Литература:
-----------
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
- Switch Transformer: https://arxiv.org/abs/2101.03961
- GShard: https://arxiv.org/abs/2006.16668
- RoPE: https://arxiv.org/abs/2104.09864
- Grouped Query Attention: https://arxiv.org/abs/2305.14236
- RMSNorm: https://arxiv.org/abs/1910.07467
"""
def __init__(self, config):
"""
Конструктор класса Mixtral.
Осуществляет инициализацию всех модулей и внутренних параметров большой языковой модели с архитектурой Mixtral/MoE.
Использует параметры из конфиг-словаря `config` для гибкой настройки модели.
Аргументы:
----------
config : dict
Словарь с основными гиперпараметрами архитектуры. Должен содержать ключи:
vocab_size (int): Размер словаря токенов.
embed_dim (int): Размер скрытого пространства (эмбеддингов).
max_position_embeddings (int): Максимальная длина токенной последовательности.
num_layers (int): Количество декодерных блоков (слоёв) в модели.
num_q_heads (int): Число query-голов (attention heads).
num_kv_heads (int): Число key-value голов (attention heads).
num_experts (int): Количество экспертов в каждом MoE-блоке.
top_k_experts (int): Сколько экспертов активируется для одного токена.
dropout (float): Dropout для регуляризации.
window_size (int): Размер окна внимания (Attention Window).
Внутри:
-------
- Инициализируются эмбеддинги токенов, позиционные эмбеддинги RoPE, Dropout.
- Строится стек из num_layers модулей MixtralDecoder с заданным количеством attention heads и экспертов.
- Финальный слой нормализации и проекция к логитам словаря (linear layer).
Пример:
-------
>>> config = {
... "vocab_size": 32000,
... "embed_dim": 512,
... "max_position_embeddings": 2048,
... "num_layers": 24,
... "num_q_heads": 8,
... "num_kv_heads": 8,
... "num_experts": 8,
... "top_k_experts": 2,
... "dropout": 0.1,
... "window_size": 256,
... }
>>> model = Mixtral(config)
Примечания:
-----------
- Конфиг модели должен быть согласован: размеры должны делиться на число голов, число экспертов и top_k_experts корректно выбраны.
- Все параметры, необходимые для построения MixtralDecoder, attention и MoE, берутся из config.
"""
super().__init__(config)
self._max_seq_len = config["max_position_embeddings"]
# Инициализация слоев
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
)
self._position_embeddings = RoPE(
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"]
)
#self._position_embeddings = PositionalEmbeddings(
# max_seq_len=max_seq_len,
# emb_size=emb_size
#)
self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([MixtralDecoder(
num_q_heads=config["num_q_heads"],
num_kv_heads=config["num_kv_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"],
num_experts=config["num_experts"],
top_k_experts=config["top_k_experts"],
window_size=config["window_size"],
rope=self._position_embeddings,
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._norm = RMSNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
"""
Прямой проход (forward) через всю модель Mixtral.
Данный метод реализует трансформацию входной последовательности токенов в логиты (предсказания вероятностей токенов словаря)
с поддержкой эффективного инференса с использованием cache (KV-кэш attention для автогенерации).
Аргументы:
----------
x : torch.Tensor
Двумерный входной тензор shape [batch_size, seq_len], где каждое значение — ID токена.
use_cache : bool, по умолчанию True
Если True — в режиме генерации модель возвращает обновлённый список кэшей attention для ускорения последовательного инференса.
Если False — attention cache не используется.
cache : list, optional
(Необязательно) Список (или None) с кэшем KV attention для каждого слоя. Используется для автогенерации текста.
Возвращает:
-----------
tuple:
- logits : torch.Tensor — выходной тензор shape [batch_size, seq_len, vocab_size] — массив логитов по токенам и словарю.
- new_cache : list или None — обновлённый cache, если используется.
Пример:
-------
>>> logits, new_cache = model(x, use_cache=True, cache=None)
>>> logits.shape # [batch_size, seq_len, vocab_size]
Примечания:
-----------
- Если используется cache — эффективно для авторегрессионной генерации (token-by-token), например, при диалогах или длинной генерации.
- Если входная последовательность длиннее max_seq_len — будет выброшено исключение.
- Если нужен только логит последнего токена — используйте slice: logits[:, -1, :]
"""
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
out = self._norm(out)
logits = self._linear(out)
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
Аргументы:
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
Возвращает:
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее max_seq_len модели.
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры:
>>> # Жадная генерация
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
>>> # Сэмплирование с температурой
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
>>> # Top-k sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
>>> # Top-p (nucleus) sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
>>> # Температура + top-k
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
Примечания:
- Одновременно использовать top_k и top_p нельзя.
- Параметры temperature, top_k, top_p работают только при do_sample=True.
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
- Mistral: https://arxiv.org/abs/2310.06825
"""
cache = None
for _ in range(max_new_tokens):
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
# Масштабируем логиты температурой
if temperature > 0:
logits_scaled = last_logits / temperature
else:
logits_scaled = last_logits
if do_sample == True and top_k != None:
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
# # Заменим все НЕ top-k логиты на -inf
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
@property
def max_seq_len(self) -> int:
return self._max_seq_len

View File

@@ -82,7 +82,9 @@ class BaseTokenizer(ABC):
List[str]: Список токенов
"""
token_ids = self.encode(text, **kwargs)
return [self.inverse_vocab.get(token_id, self.unk_token) for token_id in token_ids]
return [
self.inverse_vocab.get(token_id, self.unk_token) for token_id in token_ids
]
def get_vocab(self) -> Dict[str, int]:
"""Возвращает словарь токенизатора."""
@@ -120,16 +122,16 @@ class BaseTokenizer(ABC):
filepath: Путь для сохранения
"""
config = {
'vocab': self.vocab,
'vocab_size': self.vocab_size,
'pad_token': self.pad_token,
'unk_token': self.unk_token,
'bos_token': self.bos_token,
'eos_token': self.eos_token,
'tokenizer_type': self.__class__.__name__
"vocab": self.vocab,
"vocab_size": self.vocab_size,
"pad_token": self.pad_token,
"unk_token": self.unk_token,
"bos_token": self.bos_token,
"eos_token": self.eos_token,
"tokenizer_type": self.__class__.__name__,
}
with open(filepath, 'w', encoding='utf-8') as f:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=2)
@classmethod
@@ -143,17 +145,17 @@ class BaseTokenizer(ABC):
Returns:
BaseTokenizer: Загруженный токенизатор
"""
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, "r", encoding="utf-8") as f:
config = json.load(f)
# Создаем экземпляр токенизатора
tokenizer = cls()
tokenizer.vocab = config['vocab']
tokenizer.vocab_size = config['vocab_size']
tokenizer.pad_token = config['pad_token']
tokenizer.unk_token = config['unk_token']
tokenizer.bos_token = config['bos_token']
tokenizer.eos_token = config['eos_token']
tokenizer.vocab = config["vocab"]
tokenizer.vocab_size = config["vocab_size"]
tokenizer.pad_token = config["pad_token"]
tokenizer.unk_token = config["unk_token"]
tokenizer.bos_token = config["bos_token"]
tokenizer.eos_token = config["eos_token"]
# Создаем обратный словарь
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}

View File

@@ -1,428 +0,0 @@
"""
BPE (Byte Pair Encoding) токенизатор.
Реализация алгоритма BPE для токенизации текста.
"""
import re
from collections import defaultdict, Counter
from typing import List, Dict, Tuple, Optional
from .base_tokenizer import BaseTokenizer
class BPETokenizer(BaseTokenizer):
"""
BPE токенизатор для обработки текста.
Реализует алгоритм Byte Pair Encoding для создания субсловных токенов.
Примеры использования:
>>> tokenizer = BPETokenizer()
>>> tokenizer.train(["пример текста для обучения"], vocab_size=1000)
>>> tokens = tokenizer.encode("новый текст")
>>> text = tokenizer.decode(tokens)
"""
def __init__(self):
super().__init__()
self.merges: Dict[Tuple[str, str], int] = {}
self.pattern = r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
self.compiled_pattern = re.compile(self.pattern, re.UNICODE)
def train(self, texts: List[str], vocab_size: int = 1000, **kwargs):
"""
Обучение BPE токенизатора на текстах.
Args:
texts: Список текстов для обучения
vocab_size: Желаемый размер словаря
**kwargs: Дополнительные параметры
- min_frequency: Минимальная частота для мерджа
- special_tokens: Список специальных токенов
"""
# Инициализация базового словаря
self._initialize_vocab()
# Добавляем специальные токены если указаны
special_tokens = kwargs.get('special_tokens', [self.pad_token, self.unk_token, self.bos_token, self.eos_token])
self.add_special_tokens(special_tokens)
# Предобработка текстов
words = self._preprocess_texts(texts)
# Получаем начальные токены
vocab = self._get_initial_vocab(words)
# Выполняем BPE мерджи
self._perform_merges(vocab, vocab_size, kwargs.get('min_frequency', 2))
# Строим финальный словарь
self._build_final_vocab()
def _initialize_vocab(self):
"""Инициализирует базовый словарь."""
self.vocab.clear()
self.inverse_vocab.clear()
self.merges.clear()
self.vocab_size = 0
def _preprocess_texts(self, texts: List[str]) -> List[List[str]]:
"""
Предобработка текстов для обучения.
Args:
texts: Список текстов
Returns:
List[List[str]]: Предобработанные слова
"""
words = []
for text in texts:
# Базовая нормализация
text = text.lower().strip()
# Токенизация на слова
tokens = self.compiled_pattern.findall(text)
words.append(tokens)
return words
def _get_initial_vocab(self, words: List[List[str]]) -> Dict[str, int]:
"""
Создает начальный словарь из символов.
Args:
words: Список токенизированных текстов
Returns:
Dict[str, int]: Начальный словарь частот
"""
vocab = Counter()
for word_list in words:
for word in word_list:
# Разбиваем слово на символы и добавляем специальный символ конца слова
chars = list(word) + ['</w>']
vocab.update([''.join(chars[i:i+1]) for i in range(len(chars))])
return vocab
def _perform_merges(self, vocab: Dict[str, int], target_vocab_size: int, min_frequency: int):
"""
Выполняет BPE мерджи до достижения целевого размера словаря.
Args:
vocab: Начальный словарь
target_vocab_size: Целевой размер словаря
min_frequency: Минимальная частота для мерджа
"""
current_vocab_size = len(vocab) + len(self.vocab)
while current_vocab_size < target_vocab_size:
# Находим наиболее частую пару
pairs = self._get_stats(vocab)
if not pairs:
break
best_pair = max(pairs, key=pairs.get)
if pairs[best_pair] < min_frequency:
break
# Выполняем мердж
vocab = self._merge_vocab(vocab, best_pair)
self.merges[best_pair] = len(self.merges)
current_vocab_size += 1
def _get_stats(self, vocab: Dict[str, int]) -> Dict[Tuple[str, str], int]:
"""
Собирает статистику по парам символов.
Args:
vocab: Словарь токенов
Returns:
Dict[Tuple[str, str], int]: Частоты пар
"""
pairs = defaultdict(int)
for word, freq in vocab.items():
symbols = word.split()
for i in range(len(symbols) - 1):
pairs[symbols[i], symbols[i + 1]] += freq
return pairs
def _merge_vocab(self, vocab: Dict[str, int], pair: Tuple[str, str]) -> Dict[str, int]:
"""
Объединяет пару символов в словаре.
Args:
vocab: Исходный словарь
pair: Пара для объединения
Returns:
Dict[str, int]: Обновленный словарь
"""
new_vocab = {}
bigram = re.compile(r'(?<!\\S)' + re.escape(pair[0]) + r' ' + re.escape(pair[1]) + r'(?!\\S)')
replacement = pair[0] + pair[1]
for word in vocab:
new_word = bigram.sub(replacement, word)
new_vocab[new_word] = vocab[word]
return new_vocab
def _build_final_vocab(self):
"""Строит финальный словарь токенизатора."""
# Собираем все уникальные токены из мерджей
all_tokens = set()
# Добавляем специальные токены
all_tokens.update([self.pad_token, self.unk_token, self.bos_token, self.eos_token])
# Добавляем токены из мерджей
for pair in self.merges:
all_tokens.update(pair)
# Создаем словарь
for i, token in enumerate(sorted(all_tokens)):
self.vocab[token] = i
self.inverse_vocab = {v: k for k, v in self.vocab.items()}
self.vocab_size = len(self.vocab)
# Обновляем ID специальных токенов
self.pad_token_id = self.vocab.get(self.pad_token)
self.unk_token_id = self.vocab.get(self.unk_token)
self.bos_token_id = self.vocab.get(self.bos_token)
self.eos_token_id = self.vocab.get(self.eos_token)
def encode(self, text: str, **kwargs) -> List[int]:
"""
Кодирует текст в последовательность токенов.
Args:
text: Входной текст
**kwargs: Дополнительные параметры
- add_special_tokens: Добавлять специальные токены
Returns:
List[int]: Список идентификаторов токенов
"""
add_special_tokens = kwargs.get('add_special_tokens', False)
# Токенизация текста
tokens = self.compiled_pattern.findall(text)
# Применяем BPE к каждому токену
bpe_tokens = []
for token in tokens:
# Преобразуем токен в BPE представление
bpe_token = self._apply_bpe(token)
bpe_tokens.extend(bpe_token)
# Конвертируем в ID
token_ids = []
for token in bpe_tokens:
token_id = self.vocab.get(token, self.unk_token_id)
if token_id is not None:
token_ids.append(token_id)
# Добавляем специальные токены если нужно
if add_special_tokens:
if self.bos_token_id is not None:
token_ids.insert(0, self.bos_token_id)
if self.eos_token_id is not None:
token_ids.append(self.eos_token_id)
return token_ids
def _apply_bpe(self, token: str) -> List[str]:
"""
Применяет BPE к одному токену.
Args:
token: Входной токен
Returns:
List[str]: Список BPE токенов
"""
# Простая реализация - в реальной реализации нужно применять обученные мерджи
word = token + '</w>'
tokens = [word[i:i+1] for i in range(len(word))]
# Применяем мерджи (упрощенная версия)
# В полной реализации нужно применять все обученные мерджи
for pair in self.merges:
i = 0
while i < len(tokens) - 1:
if tokens[i] == pair[0] and tokens[i + 1] == pair[1]:
tokens[i] = tokens[i] + tokens[i + 1]
del tokens[i + 1]
else:
i += 1
return tokens
def decode(self, tokens: List[int], **kwargs) -> str:
"""
Декодирует последовательность токенов в текст.
Args:
tokens: Список идентификаторов токенов
**kwargs: Дополнительные параметры
- skip_special_tokens: Пропускать специальные токены
Returns:
str: Декодированный текст
"""
skip_special_tokens = kwargs.get('skip_special_tokens', True)
# Конвертируем ID в токены
token_strings = []
for token_id in tokens:
token = self.inverse_vocab.get(token_id, self.unk_token)
# Пропускаем специальные токены если нужно
if skip_special_tokens and token in [self.pad_token, self.unk_token, self.bos_token, self.eos_token]:
continue
token_strings.append(token)
# Объединяем токены в текст
text = ''.join(token_strings)
# Убираем маркер конца слова
text = text.replace('</w>', ' ')
return text.strip()
def save(self, filepath: str):
"""
Сохраняет BPE токенизатор в файл.
Args:
filepath: Путь для сохранения
"""
import json
config = {
'vocab': self.vocab,
'merges': {f"{k[0]} {k[1]}": v for k, v in self.merges.items()},
'vocab_size': self.vocab_size,
'pad_token': self.pad_token,
'unk_token': self.unk_token,
'bos_token': self.bos_token,
'eos_token': self.eos_token,
'pattern': self.pattern,
'tokenizer_type': self.__class__.__name__
}
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(config, f, ensure_ascii=False, indent=2)
@classmethod
def load(cls, filepath: str):
"""
Загружает BPE токенизатор из файла.
Args:
filepath: Путь к файлу
Returns:
BPETokenizer: Загруженный токенизатор
"""
import json
with open(filepath, 'r', encoding='utf-8') as f:
config = json.load(f)
tokenizer = cls()
tokenizer.vocab = config['vocab']
tokenizer.vocab_size = config['vocab_size']
tokenizer.pad_token = config['pad_token']
tokenizer.unk_token = config['unk_token']
tokenizer.bos_token = config['bos_token']
tokenizer.eos_token = config['eos_token']
tokenizer.pattern = config.get('pattern', tokenizer.pattern)
tokenizer.compiled_pattern = re.compile(tokenizer.pattern, re.UNICODE)
# Восстанавливаем мерджи
merges = config.get('merges', {})
tokenizer.merges = {}
for k, v in merges.items():
parts = k.split()
if len(parts) == 2:
tokenizer.merges[(parts[0], parts[1])] = v
# Создаем обратный словарь
tokenizer.inverse_vocab = {v: k for k, v in tokenizer.vocab.items()}
# Обновляем ID специальных токенов
tokenizer.pad_token_id = tokenizer.vocab.get(tokenizer.pad_token)
tokenizer.unk_token_id = tokenizer.vocab.get(tokenizer.unk_token)
tokenizer.bos_token_id = tokenizer.vocab.get(tokenizer.bos_token)
tokenizer.eos_token_id = tokenizer.vocab.get(tokenizer.eos_token)
return tokenizer
# Упрощенная версия для быстрого старта
class SimpleBPETokenizer(BPETokenizer):
"""
Упрощенная версия BPE токенизатора для демонстрации.
"""
def train(self, texts: List[str], vocab_size: int = 1000, **kwargs):
"""Упрощенное обучение для демонстрации."""
# Инициализация базового словаря
self._initialize_vocab()
# Добавляем базовые токены
special_tokens = [self.pad_token, self.unk_token, self.bos_token, self.eos_token]
self.add_special_tokens(special_tokens)
# Простая реализация - собираем все символы
all_chars = set()
for text in texts:
all_chars.update(text)
# Добавляем символы в словарь
for char in sorted(all_chars):
if char not in self.vocab:
self.vocab[char] = len(self.vocab)
self.inverse_vocab = {v: k for k, v in self.vocab.items()}
self.vocab_size = len(self.vocab)
# Обновляем ID специальных токенов
self.pad_token_id = self.vocab.get(self.pad_token)
self.unk_token_id = self.vocab.get(self.unk_token)
self.bos_token_id = self.vocab.get(self.bos_token)
self.eos_token_id = self.vocab.get(self.eos_token)
def encode(self, text: str, **kwargs) -> List[int]:
"""Упрощенное кодирование - разбиваем на символы."""
add_special_tokens = kwargs.get('add_special_tokens', False)
token_ids = []
for char in text:
token_id = self.vocab.get(char, self.unk_token_id)
if token_id is not None:
token_ids.append(token_id)
if add_special_tokens:
if self.bos_token_id is not None:
token_ids.insert(0, self.bos_token_id)
if self.eos_token_id is not None:
token_ids.append(self.eos_token_id)
return token_ids
def decode(self, tokens: List[int], **kwargs) -> str:
"""Упрощенное декодирование."""
skip_special_tokens = kwargs.get('skip_special_tokens', True)
chars = []
for token_id in tokens:
char = self.inverse_vocab.get(token_id, self.unk_token)
if skip_special_tokens and char in [self.pad_token, self.unk_token, self.bos_token, self.eos_token]:
continue
chars.append(char)
return ''.join(chars)

View File

@@ -10,16 +10,54 @@ from .base_tokenizer import BaseTokenizer
class BPETokenizer(BaseTokenizer):
"""
BPE токенизатор для обработки текста.
BpeTokenizer — реализация токенизатора на алгоритме byte pair encoding (BPE).
Реализует алгоритм Byte Pair Encoding для создания субсловных токенов.
Использует вашу реализацию BPE.
Назначение:
-----------
- Преобразует открытый текст (строки, bytes) в последовательность числовых токенов для подачи в LLM и обратно.
- Разбивает текст на сабслова (байтовые пары), эффективно кодируя редкие слова длинными последовательностями, а частые — единичными токенами.
- Является стандартом де-факто в современных языковых моделях (GPT, LLaMA, BLOOM, Mistral, HuggingFace).
Примеры использования:
>>> tokenizer = BPETokenizer()
>>> tokenizer.train(["пример текста для обучения"], vocab_size=1000)
>>> tokens = tokenizer.encode("новый текст")
Как работает BPE:
-----------------
1. Строится словарь из наиболее популярных пар символов/субстрок.
2. Текст замещается наиболее длинными subword-подстроками из vocabulary (жадно).
3. Итог: многомиллионное лексическое пространство сокращается до компактного набора subword pieces.
Особенности алгоритма:
----------------------
- Отлично работает на всех языках, включая rare/compound/inflectable.
- Гибко масштабируется под размер итогового словаря/token space.
- Обычно хранит mapping (str/bytes → int и int → str/bytes) в JSON или словарном файле.
- Может использовать кастомные сепараторы, handle unknown.
Аргументы конструктора:
-----------------------
vocab_path: str
Путь к файлу BPE vocabulary (JSON, txt, в зависимости от реализации).
merges_path: str, optional
Путь к списку merge-правил (если используется блочное файловое раздельное хранение).
unk_token: str, optional
Токен для неизвестных последовательностей (по дефолту '[UNK]' или '<unk>').
pad_token, bos_token, eos_token: str, optional
Special tokens, если нужны для вашей архитектуры.
lowercase: bool, optional
Приводить ли текст к нижнему регистру перед токенизацией.
Пример:
-------
>>> tokenizer = BpeTokenizer(vocab_path=\"bpe_vocab.json\")
>>> tokens = tokenizer.encode(\"Hello, world!\")
>>> print(tokens) # [15496, 11, ...]
>>> text = tokenizer.decode(tokens)
>>> print(text) # 'Hello, world!'
References:
-----------
- Sennrich et al, \"Neural Machine Translation of Rare Words with Subword Units\", 2015: https://arxiv.org/abs/1508.07909
- GPT-2 tokenization: https://github.com/openai/gpt-2
- HuggingFace tokenizers overview: https://huggingface.co/docs/tokenizers/index
- Visually: https://guillaume-be.github.io/2021-05-21/byte-pair-encoding/
"""
def __init__(self):
@@ -61,7 +99,10 @@ class BPETokenizer(BaseTokenizer):
break # нет пар — выходим
# Находим самую частую пару (в случае равенства — та, что встретилась первой)
most_frequent_pair = max(pair_freq.items(), key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])))[0]
most_frequent_pair = max(
pair_freq.items(),
key=lambda x: (x[1], -self._pair_first_index(sequence, x[0])),
)[0]
# Создаем новый токен
new_token = most_frequent_pair[0] + most_frequent_pair[1]
@@ -71,7 +112,10 @@ class BPETokenizer(BaseTokenizer):
new_sequence = []
while i < len(sequence):
if i < len(sequence) - 1 and (sequence[i], sequence[i + 1]) == most_frequent_pair:
if (
i < len(sequence) - 1
and (sequence[i], sequence[i + 1]) == most_frequent_pair
):
new_sequence.append(new_token)
i += 2 # пропускаем два символа — заменённую пару
else:
@@ -86,7 +130,10 @@ class BPETokenizer(BaseTokenizer):
self.vocab_size = len(self.vocab)
# Добавляем специальные токены если указаны
special_tokens = kwargs.get('special_tokens', [self.pad_token, self.unk_token, self.bos_token, self.eos_token])
special_tokens = kwargs.get(
"special_tokens",
[self.pad_token, self.unk_token, self.bos_token, self.eos_token],
)
self.add_special_tokens(special_tokens)
def _pair_first_index(self, sequence, pair):
@@ -94,21 +141,27 @@ class BPETokenizer(BaseTokenizer):
for i in range(len(sequence) - 1):
if (sequence[i], sequence[i + 1]) == pair:
return i
return float('inf') # если пара не найдена (в теории не должно случиться)
return float("inf") # если пара не найдена (в теории не должно случиться)
def encode(self, text: str, **kwargs) -> List[int]:
"""
Кодирует текст в последовательность токенов.
Токенизирует входной текст в список числовых токенов (индексов).
Args:
text: Входной текст
**kwargs: Дополнительные параметры
- add_special_tokens: Добавлять специальные токены
-----
text: str
Входная строка/текст для токенизации.
Returns:
List[int]: Список идентификаторов токенов
--------
List[int] — последовательность индексов из vocabulary.
Пример:
-------
>>> ids = tokenizer.encode(\"The quick brown fox\")
>>> print(ids)
"""
add_special_tokens = kwargs.get('add_special_tokens', False)
add_special_tokens = kwargs.get("add_special_tokens", False)
# 1. Разбиваем текст на токены-символы
sequence = list(text)
@@ -119,7 +172,9 @@ class BPETokenizer(BaseTokenizer):
while i < len(text):
# 3.1 Найти все токены в словаре, начинающиеся с text[i]
start_char = text[i]
result = [token for token in self.vocab_list if token.startswith(start_char)]
result = [
token for token in self.vocab_list if token.startswith(start_char)
]
# 3.2 Выбрать самый длинный подходящий токен
find_token = self._find_max_matching_token(text[i:], result)
if find_token is None:
@@ -164,29 +219,44 @@ class BPETokenizer(BaseTokenizer):
def decode(self, tokens: List[int], **kwargs) -> str:
"""
Декодирует последовательность токенов в текст.
Декодирует последовательность токенов обратно в текстовую строку.
Args:
tokens: Список идентификаторов токенов
**kwargs: Дополнительные параметры
- skip_special_tokens: Пропускать специальные токены
-----
ids: List[int]
Список токен-индексов для распаковки.
Returns:
str: Декодированный текст
--------
text: str
Оригинальный (или приближённый) раскодированный текст.
Пример:
-------
>>> tokens = [15496, 11, 318, ...]
>>> text = tokenizer.decode(tokens)
"""
skip_special_tokens = kwargs.get('skip_special_tokens', True)
skip_special_tokens = kwargs.get("skip_special_tokens", True)
# Фильтруем специальные токены если нужно
if skip_special_tokens:
tokens = [tid for tid in tokens if tid not in [
self.pad_token_id, self.unk_token_id, self.bos_token_id, self.eos_token_id
]]
tokens = [
tid
for tid in tokens
if tid
not in [
self.pad_token_id,
self.unk_token_id,
self.bos_token_id,
self.eos_token_id,
]
]
# Конвертируем ID в токены
token_strings = self._ids_to_tokens(tokens)
# Объединяем токены в текст
return ''.join(token_strings)
return "".join(token_strings)
def _ids_to_tokens(self, ids: List[int]) -> List[str]:
"""Конвертирует список Ids в их tokens"""
@@ -211,18 +281,18 @@ class BPETokenizer(BaseTokenizer):
merges_serializable = {f"{k[0]},{k[1]}": v for k, v in self.merges.items()}
config = {
'vocab': self.vocab,
'vocab_size': self.vocab_size,
'pad_token': self.pad_token,
'unk_token': self.unk_token,
'bos_token': self.bos_token,
'eos_token': self.eos_token,
'tokenizer_type': self.__class__.__name__,
'merges': merges_serializable,
'vocab_list': self.vocab_list
"vocab": self.vocab,
"vocab_size": self.vocab_size,
"pad_token": self.pad_token,
"unk_token": self.unk_token,
"bos_token": self.bos_token,
"eos_token": self.eos_token,
"tokenizer_type": self.__class__.__name__,
"merges": merges_serializable,
"vocab_list": self.vocab_list,
}
with open(filepath, 'w', encoding='utf-8') as f:
with open(filepath, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=2)
@classmethod
@@ -238,23 +308,23 @@ class BPETokenizer(BaseTokenizer):
"""
import json
with open(filepath, 'r', encoding='utf-8') as f:
with open(filepath, "r", encoding="utf-8") as f:
config = json.load(f)
# Создаем экземпляр токенизатора
tokenizer = cls()
tokenizer.vocab = config['vocab']
tokenizer.vocab_size = config['vocab_size']
tokenizer.pad_token = config['pad_token']
tokenizer.unk_token = config['unk_token']
tokenizer.bos_token = config['bos_token']
tokenizer.eos_token = config['eos_token']
tokenizer.vocab_list = config['vocab_list']
tokenizer.vocab = config["vocab"]
tokenizer.vocab_size = config["vocab_size"]
tokenizer.pad_token = config["pad_token"]
tokenizer.unk_token = config["unk_token"]
tokenizer.bos_token = config["bos_token"]
tokenizer.eos_token = config["eos_token"]
tokenizer.vocab_list = config["vocab_list"]
# Восстанавливаем кортежи из строк
tokenizer.merges = {}
for k, v in config['merges'].items():
parts = k.split(',')
for k, v in config["merges"].items():
parts = k.split(",")
if len(parts) == 2:
tokenizer.merges[(parts[0], parts[1])] = v
@@ -275,4 +345,5 @@ class SimpleBPETokenizer(BPETokenizer):
Упрощенная версия BPE токенизатора для демонстрации.
Наследует вашу реализацию, но может быть упрощена при необходимости.
"""
pass

View File

@@ -1,142 +0,0 @@
import torch
from torch.utils.data import Dataset
from typing import List, Any
class TextDataset(Dataset):
"""
Простой датасет для языкового моделирования (LLM).
Работает с любым токенизатором, реализующим интерфейс BaseTokenizer.
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
"""
Инициализация датасета.
Args:
texts: Список текстов для обучения
tokenizer: Токенизатор с методами encode/decode
block_size: Максимальная длина последовательности
"""
self.examples = []
self.tokenizer = tokenizer
self.block_size = block_size
for text in texts:
# Кодируем текст в токены
input_ids = tokenizer.encode(text, add_special_tokens=False)
# Обрезаем или дополняем до нужной длины
if len(input_ids) > block_size:
input_ids = input_ids[:block_size]
else:
# Дополняем pad_token_id
pad_token_id = getattr(tokenizer, 'pad_token_id', 0)
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
self.examples.append(input_ids)
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}
class StreamingTextDataset(Dataset):
"""
Датасет для потоковой обработки больших текстов.
Токенизация происходит на лету, что экономит память.
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128):
self.texts = texts
self.tokenizer = tokenizer
self.block_size = block_size
# Получаем pad_token_id из токенизатора
self.pad_token_id = getattr(tokenizer, 'pad_token_id', 0)
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
# Токенизация на лету
input_ids = self.tokenizer.encode(text, add_special_tokens=False)
# Обрезаем или дополняем до нужной длины
if len(input_ids) > self.block_size:
input_ids = input_ids[:self.block_size]
else:
input_ids = input_ids + [self.pad_token_id] * (self.block_size - len(input_ids))
input_ids = torch.tensor(input_ids, dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}
class TextDatasetWithSpecialTokens(TextDataset):
"""
Расширенная версия TextDataset с поддержкой специальных токенов.
"""
def __init__(self, texts: List[str], tokenizer: Any, block_size: int = 128,
add_bos: bool = False, add_eos: bool = False):
"""
Args:
texts: Список текстов
tokenizer: Токенизатор
block_size: Максимальная длина
add_bos: Добавлять токен начала последовательности
add_eos: Добавлять токен конца последовательности
"""
self.examples = []
self.tokenizer = tokenizer
self.block_size = block_size
self.add_bos = add_bos
self.add_eos = add_eos
for text in texts:
# Кодируем с специальными токенами
input_ids = tokenizer.encode(
text,
add_special_tokens=True,
add_bos_token=add_bos,
add_eos_token=eos
)
# Учитываем специальные токены при обрезке/дополнении
effective_block_size = block_size
if add_bos:
effective_block_size -= 1
if add_eos:
effective_block_size -= 1
if len(input_ids) > effective_block_size:
input_ids = input_ids[:effective_block_size]
# Добавляем специальные токены если нужно
if add_bos and hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None:
input_ids = [tokenizer.bos_token_id] + input_ids
if add_eos and hasattr(tokenizer, 'eos_token_id') and tokenizer.eos_token_id is not None:
input_ids = input_ids + [tokenizer.eos_token_id]
# Дополняем до полной длины
pad_token_id = getattr(tokenizer, 'pad_token_id', 0)
if len(input_ids) < block_size:
input_ids = input_ids + [pad_token_id] * (block_size - len(input_ids))
self.examples.append(input_ids)
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
input_ids = torch.tensor(self.examples[idx], dtype=torch.long)
labels = input_ids.clone()
return {"input_ids": input_ids, "labels": labels}

View File

@@ -1,8 +1,71 @@
"""
Модуль оптимизации для обучения нейронных сетей.
В данном модуле реализована функция выбора и инициализации оптимизаторов, наиболее популярных при обучении глубоких нейросетей:
- AdamW
- Adam
- SGD
Теоретическое обоснование:
--------------------------
Задача оптимизации в обучении нейросети заключается в минимизации функции потерь (Loss) по параметрам модели W. Современные методы базируются на стохастическом градиентном спуске (SGD), а также на его адаптивных модификациях (Adam, AdamW).
**SGD** (Stochastic Gradient Descent) — стохастический градиентный спуск:
W_{t+1} = W_t - \eta \nabla_W L(W_t)
Здесь \eta — шаг обучения, \nabla_W — градиент по параметрам. SGD позволяет случайно выбирать подмножество обучающих данных для каждой итерации, что ускоряет процесс и уменьшает избыточную корреляцию между примерами.
**Adam** (Adaptive Moment Estimation) — адаптивный алгоритм, который использует скользящую среднюю не только градиентов, но и их квадратов:
m_t = \beta_1 m_{t-1} + (1-\beta_1) \nabla_W L(W_t)
v_t = \beta_2 v_{t-1} + (1-\beta_2) (\nabla_W L(W_t))^2
W_{t+1} = W_t - \eta m_t/(\sqrt{v_t}+\epsilon)
Где \beta_1, \beta_2 — коэффициенты экспоненциального сглаживания.
**AdamW** — модификация Adam, в которой weight decay (имплицитная L2-регуляризация) вводится корректно, отдельно от шага градиента, что улучшает обобщающую способность моделей:
W_{t+1} = W_t - \eta [ m_t/(\sqrt{v_t}+\epsilon) + \lambda W_t ]
Где \lambda — коэффициент weight decay.
Детальное описание: https://arxiv.org/abs/1711.05101
Пример использования:
---------------------
>>> optimizer = get_optimizer(model, lr=3e-4, weight_decay=0.01, optimizer_type="adamw")
>>> for batch in dataloader:
... loss = model(batch)
... loss.backward()
... optimizer.step()
... optimizer.zero_grad()
"""
import torch.optim as optim
def get_optimizer(model, lr=3e-4, weight_decay=0.01, optimizer_type="adamw"):
"""
Возвращает оптимизатор для обучения модели.
Фабричная функция для создания оптимизатора PyTorch по выбранному типу.
Параметры
---------
model : torch.nn.Module
Модель, параметры которой требуется оптимизировать.
lr : float, по умолчанию 3e-4
Шаг обучения (learning rate).
weight_decay : float, по умолчанию 0.01
Коэффициент weight decay (L2-регуляризации).
optimizer_type : str, по умолчанию 'adamw'
Тип оптимизатора: 'adamw', 'adam' или 'sgd'.
Возвращаемое значение
---------------------
torch.optim.Optimizer
Объект-оптимизатор, готовый к использованию.
Исключения
----------
ValueError: Если передан неизвестный тип оптимизатора.
Пример использования:
---------------------
>>> optimizer = get_optimizer(model, lr=1e-3, optimizer_type='sgd')
"""
if optimizer_type.lower() == "adamw":
return optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

View File

@@ -1,13 +1,66 @@
"""
Модуль для управления динамикой шага обучения (learning rate scheduling) при обучении нейронных сетей.
Теоретическое обоснование:
--------------------------
Плавная динамика шага обучения существенно влияет на сходимость и итоговое качество моделей. Введение этапа "разогрева" (warmup) — техники, при которой шаг обучения начинается с нуля и постепенно увеличивается до целевого значения, снижает вероятность неустойчивых градиентов на старте обучения. Подобная стратегия показала свою эффективность для крупных нейронных сетей, особенно в трансформерах (Vaswani et al, 2017, https://arxiv.org/abs/1706.03762).
Линейный scheduler с warmup задаёт динамику learning rate по формуле:
- если current_step < num_warmup_steps:
lr = lr_init * (current_step / num_warmup_steps)
- иначе:
lr = lr_init * max(0, (num_training_steps - current_step) / (num_training_steps - num_warmup_steps))
Пример использования:
---------------------
>>> optimizer = get_optimizer(model, lr=3e-4)
>>> scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=10000)
>>> for step in range(num_training_steps):
... optimizer.step()
... scheduler.step()
"""
from torch.optim.lr_scheduler import LambdaLR
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps):
"""
Линейный планировщик обучения с warmup.
Создаёт линейный планировщик изменения шага обучения (learning rate) с этапом warmup для оптимизатора PyTorch.
Аргументы
---------
optimizer : torch.optim.Optimizer
Оптимизатор, для которого применяется scheduler.
num_warmup_steps : int
Количество шагов разогрева (warmup) — начиная с нулевого шага и плавного увеличения lr до номинального значения.
num_training_steps : int
Общее количество шагов (эпох/итераций) обучения модели.
Возвращаемое значение
---------------------
torch.optim.lr_scheduler.LambdaLR
Планировщик lr, который следует вызывать после каждого optimizer.step() во время обучения.
Теоретическая справка
---------------------
Такой scheduler позволяет повысить стабильность и устойчивость обучения крупных моделей (особенно трансформеров), предотвращая резкие скачки градиентов в начале.
Пример:
-------
>>> optimizer = get_optimizer(model, lr=3e-4)
>>> scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=10000)
>>> for step in range(num_training_steps):
... optimizer.step()
... scheduler.step()
"""
def lr_lambda(current_step):
# Линейный рост lr на этапе разогрева
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))
# Линейное затухание lr после разогрева
return max(
0.0,
float(num_training_steps - current_step)
/ float(max(1, num_training_steps - num_warmup_steps)),
)
return LambdaLR(optimizer, lr_lambda)

View File

@@ -1,3 +1,22 @@
"""
Модуль для организации процесса обучения больших языковых моделей (LLM).
Научное и техническое обоснование
----------------------------------
Эффективное обучение современных трансформеров (GPT, LLaMA, Mistral и др.) опирается на принципы языкового моделирования (Language Modeling):
- Предсказание вероятности следующего токена на основе предыдущих.
- Использование функции потерь кросс-энтропии (cross-entropy) с маскированием паддингов.
- Циклы обратного распространения ошибки (backpropagation), оптимизационные алгоритмы (например, AdamW), управление шагом обучения (scheduler с warmup), обрезка градиентов (grad clipping).
Реализация объединяет лучшие практики обучения LLM, универсальный API к моделям, датасетам, оптимизаторам и lr-схемам.
Подробнее: Vaswani et al. "Attention is All You Need" (2017), Radford et al. "Language Models are Unsupervised Multitask Learners" (2019)
Пример использования
--------------------
>>> trainer = Trainer(model, train_dataset, val_dataset, lr=3e-4, batch_size=8, num_epochs=3, warmup_steps=100)
>>> trainer.train()
"""
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
@@ -5,15 +24,75 @@ from tqdm import tqdm
from llm.training.optimizer import get_optimizer
from llm.training.scheduler import get_linear_schedule_with_warmup
class Trainer:
"""
Универсальный класс обучения LLM (GPT, LLaMA, Mistral и т.д.)
Универсальный и расширяемый класс для обучения больших языковых моделей (Large Language Models, LLM).
Поддерживаются архитектуры семейства GPT, LLaMA, Mistral и другие автогрессивные модели.
Объединяет:
- Тренировку по задаче языкового моделирования (Causal LM)
- Cross-entropy loss с автоматическим сдвигом логитов/меток
- Поддержку Grad Clipping, Scheduler, Validation
- Унифицированный даталоадер, автоматический выбор устройства (CPU/GPU)
Атрибуты
--------
model : torch.nn.Module
Модель для обучения языковому моделированию
train_loader : torch.utils.data.DataLoader
Даталоадер обучающего набора
val_loader : torch.utils.data.DataLoader или None
Даталоадер валидационного набора (если задан)
optimizer : torch.optim.Optimizer
Оптимизатор параметров модели
scheduler : torch.optim.lr_scheduler.LambdaLR
Планировщик learning rate (инициализируется в train)
device : torch.device
Устройство (CPU или CUDA), куда помещается модель
num_epochs : int
Количество эпох обучения
warmup_steps : int
Число шагов warmup для scheduler
"""
def __init__(self, model, train_dataset, val_dataset=None, lr=3e-4, batch_size=8, num_epochs=3, warmup_steps=100):
def __init__(
self,
model,
train_dataset,
val_dataset=None,
lr=3e-4,
batch_size=8,
num_epochs=3,
warmup_steps=100,
):
"""
Инициализация обучающего класса Trainer.
Аргументы
---------
model : torch.nn.Module
Модель для обучения (например, GPT, LLaMA, Mistral).
train_dataset : torch.utils.data.Dataset
Обучающий датасет с полями input_ids и labels.
val_dataset : torch.utils.data.Dataset, optional
Валидационный датасет для контроля качества обучения.
lr : float, default=3e-4
Начальный шаг обучения.
batch_size : int, default=8
Размер обучающего мини-батча.
num_epochs : int, default=3
Количество эпох обучения.
warmup_steps : int, default=100
Количество шагов разогрева (warmup) learning rate.
"""
self.model = model
self.train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
self.val_loader = DataLoader(val_dataset, batch_size=batch_size) if val_dataset else None
self.train_loader = DataLoader(
train_dataset, batch_size=batch_size, shuffle=True
)
self.val_loader = (
DataLoader(val_dataset, batch_size=batch_size) if val_dataset else None
)
self.optimizer = get_optimizer(model, lr=lr)
self.scheduler = None
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -23,44 +102,74 @@ class Trainer:
def compute_lm_loss(self, logits, labels):
"""
Вычисляет loss для языкового моделирования.
Сдвигает логиты и метки для предсказания следующего токена.
Вычисляет функцию потерь (loss) для задачи автогрессивного языкового моделирования.
Производит сдвиг логитов и меток: предсказания делаются для следующего токена.
Используется кросс-энтропия (CrossEntropyLoss), что соответствует максимизации логарифма правдоподобия:
L = -log P(w_{t+1} | w_1,...,w_t)
Аргументы
---------
logits : torch.Tensor
Логиты модели: (batch_size, seq_len, vocab_size)
labels : torch.Tensor
Правильные метки: (batch_size, seq_len)
Возвращаемое значение
---------------------
loss : torch.Tensor
Средний loss по batch.
"""
# Сдвигаем логиты и метки для языкового моделирования
# Сдвигаем логиты и метки для языкового моделирования (автогрессия)
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Вычисляем cross-entropy loss
# CrossEntropyLoss (игнорируем паддинги: ignore_index=-100)
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100 # Игнорируем padding tokens
ignore_index=-100, # Padding токены не участвуют в loss
)
return loss
def train(self):
"""
Запускает процесс обучения модели по заданному числу эпох.
В процессе:
- Применяет optimizer, scheduler с warmup и decay, grad clipping (обрезка градиентов)
- Вызывает функцию потерь для языкового моделирования
- Показывает динамику процесса (tqdm)
- После каждой эпохи возможно проведение валидации
Параметры задаются на этапе инициализации Trainer.
"""
total_steps = len(self.train_loader) * self.num_epochs
self.scheduler = get_linear_schedule_with_warmup(self.optimizer, self.warmup_steps, total_steps)
self.scheduler = get_linear_schedule_with_warmup(
self.optimizer, self.warmup_steps, total_steps
)
self.loss_history = [] # добавлено: лог средних потерь
for epoch in range(self.num_epochs):
self.model.train()
total_loss = 0
progress_bar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}")
progress_bar = tqdm(
self.train_loader, desc=f"Epoch {epoch+1}/{self.num_epochs}"
)
for batch in progress_bar:
self.optimizer.zero_grad()
input_ids = batch["input_ids"].to(self.device)
labels = batch["labels"].to(self.device)
# Универсально обрабатываем выход (tuple/logits)
# Универсально обрабатываем выходы модели: tuple или просто tensor (logits)
outputs = self.model(input_ids)
if isinstance(outputs, tuple):
logits = outputs[0]
else:
logits = outputs
# Trainer вычисляет loss
# Вычисляем loss автогрессивной LM-задачи
loss = self.compute_lm_loss(logits, labels)
loss.backward()
@@ -72,12 +181,19 @@ class Trainer:
progress_bar.set_postfix(loss=loss.item())
avg_loss = total_loss / len(self.train_loader)
self.loss_history.append(avg_loss) # добавлено: запоминаем loss
print(f"Epoch {epoch+1} finished — avg loss: {avg_loss:.4f}")
if self.val_loader:
self.evaluate()
def evaluate(self):
"""
Оценивает модель на валидационном датасете (если задан).
В режиме eval() модели отключается dropout и все стохастические элементы.
Возвращает среднее значение функции потерь (loss) по всему validation set.
"""
self.model.eval()
total_loss = 0

View File

@@ -58,7 +58,7 @@ def gpt_config(vocab_size, embed_dim, num_heads, num_layers):
"num_heads": num_heads,
"num_layers": num_layers,
"max_position_embeddings": 1024,
"dropout": 0.1
"dropout": 0.1,
}
@@ -68,12 +68,14 @@ def random_inputs(batch_size, seq_len, vocab_size):
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len))
return input_ids
@pytest.fixture
def random_float_inputs(batch_size, seq_len, embed_dim):
"""Generate random floating point input tensors for testing feed forward."""
inputs = torch.randn(batch_size, seq_len, embed_dim)
return inputs
@pytest.fixture
def random_embeddings(batch_size, seq_len, embed_dim):
"""Generate random embedding tensors for testing attention modules."""

View File

@@ -0,0 +1,65 @@
import torch
import pytest
from llm.core.cached_decoder import CachedDecoder
from llm.core.feed_forward import FeedForward
@pytest.fixture
def decoder_config():
return dict(
num_heads=4,
emb_size=32,
head_size=8,
feed_forward_layer=FeedForward(emb_size=32, dropout=0.1, activation="gelu"),
max_seq_len=64,
dropout=0.1
)
def test_cached_decoder_init(decoder_config):
model = CachedDecoder(**decoder_config)
assert model is not None
# Main attention block is usually stored as _heads or _attention (which itself includes _q _k _v)
assert hasattr(model, '_heads') or hasattr(model, '_attention')
assert hasattr(model, '_ff') or hasattr(model, 'feed_forward_layer')
def test_cached_decoder_forward_shape(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 3, 10, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
output, cache = model(x, use_cache=True)
assert output.shape == (batch, seq_len, emb_size)
assert cache is not None
def test_cached_decoder_forward_no_cache(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 12, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
output, cache = model(x, use_cache=False)
assert output.shape == (batch, seq_len, emb_size)
assert cache is None
def test_cached_decoder_error_on_long_seq(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 1, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
with pytest.raises(ValueError):
model(x)
def test_cached_decoder_backward(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 7, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
output, cache = model(x)
loss = output.sum()
loss.backward()
assert x.grad is not None
def test_cached_decoder_kv_cache_chain(decoder_config):
model = CachedDecoder(**decoder_config)
batch, seq_len, emb_size = 1, 4, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
# Первый проход — кэша нет
_, cache = model(x, use_cache=True)
# Второй проход — передаём кэш, добавляем еще токен:
next_x = torch.randn(batch, 1, emb_size)
_, cache2 = model(next_x, use_cache=True, cache=cache)
assert cache2 is not None

View File

@@ -17,10 +17,10 @@ class TestFeedForward:
assert ff is not None
# Check internal layers
assert hasattr(ff, '_layer1')
assert hasattr(ff, '_layer2')
assert hasattr(ff, '_activation')
assert hasattr(ff, '_dropout')
assert hasattr(ff, "_layer1")
assert hasattr(ff, "_layer2")
assert hasattr(ff, "_activation")
assert hasattr(ff, "_dropout")
# Check layer dimensions
expected_hidden_dim = embed_dim * 4 # Default expansion factor
@@ -101,10 +101,12 @@ class TestFeedForward:
# Check that gradients are computed for learnable parameters
assert ff._layer1.weight.grad is not None
assert ff._layer2.weight.grad is not None
assert not torch.allclose(ff._layer1.weight.grad,
torch.zeros_like(ff._layer1.weight.grad))
assert not torch.allclose(ff._layer2.weight.grad,
torch.zeros_like(ff._layer2.weight.grad))
assert not torch.allclose(
ff._layer1.weight.grad, torch.zeros_like(ff._layer1.weight.grad)
)
assert not torch.allclose(
ff._layer2.weight.grad, torch.zeros_like(ff._layer2.weight.grad)
)
def test_device_consistency(self, embed_dim, random_float_inputs, device):
"""Test that FeedForward works on correct device."""
@@ -167,11 +169,19 @@ class TestFeedForward:
ff = FeedForward(embed_dim)
# Check that weights are not all zeros
assert not torch.allclose(ff._layer1.weight, torch.zeros_like(ff._layer1.weight))
assert not torch.allclose(ff._layer2.weight, torch.zeros_like(ff._layer2.weight))
assert not torch.allclose(
ff._layer1.weight, torch.zeros_like(ff._layer1.weight)
)
assert not torch.allclose(
ff._layer2.weight, torch.zeros_like(ff._layer2.weight)
)
# Check that biases are not all zeros (they should be initialized with some values)
if ff._layer1.bias is not None:
assert not torch.allclose(ff._layer1.bias, torch.zeros_like(ff._layer1.bias))
assert not torch.allclose(
ff._layer1.bias, torch.zeros_like(ff._layer1.bias)
)
if ff._layer2.bias is not None:
assert not torch.allclose(ff._layer2.bias, torch.zeros_like(ff._layer2.bias))
assert not torch.allclose(
ff._layer2.bias, torch.zeros_like(ff._layer2.bias)
)

View File

@@ -0,0 +1,60 @@
import torch
import pytest
from llm.core.geglu import GeGLU
@pytest.fixture
def geglu():
return GeGLU(emb_size=16, dropout=0.1)
def test_forward_shape(geglu):
x = torch.randn(2, 5, 16)
y = geglu(x)
assert y.shape == x.shape
def test_forward_no_batch(geglu):
x = torch.randn(1, 16)
y = geglu(x.unsqueeze(0))
assert y.shape == (1, 1, 16)
@pytest.mark.skip(reason="float16 not supported without parameter casting")
def test_forward_dtype_fp16():
geglu = GeGLU(emb_size=8, dropout=0.0)
x = torch.randn(2, 4, 8).half()
y = geglu(x)
assert y.shape == x.shape
assert y.dtype == torch.float16
def test_forward_no_dropout():
geglu = GeGLU(emb_size=4, dropout=0.0)
x = torch.randn(3, 2, 4)
y = geglu(x)
assert not torch.isnan(y).any()
assert not torch.isinf(y).any()
def test_gradient_flow(geglu):
x = torch.randn(3, 8, 16, requires_grad=True)
y = geglu(x)
y.sum().backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_forward_repeatability():
torch.manual_seed(42)
geglu = GeGLU(emb_size=8, dropout=0.0)
x = torch.randn(3, 2, 8)
y1 = geglu(x)
torch.manual_seed(42)
geglu2 = GeGLU(emb_size=8, dropout=0.0)
x2 = torch.randn(3, 2, 8)
y2 = geglu2(x2)
assert torch.allclose(y1, y2, atol=1e-5)
def test_edge_small_large():
geglu = GeGLU(emb_size=2, dropout=0.0)
x = torch.randn(2, 2, 2)
y = geglu(x)
assert y.shape == x.shape
geglu = GeGLU(emb_size=256, dropout=0.0)
x = torch.randn(1, 1, 256)
y = geglu(x)
assert y.shape == x.shape

View File

@@ -0,0 +1,46 @@
import torch
import pytest
from llm.core.gelu import GELU
def test_gelu_shapes_and_dtype():
gelu = GELU()
x = torch.randn(4, 16, 8)
y = gelu(x)
assert y.shape == x.shape
assert y.dtype == x.dtype
def test_gelu_known_values():
gelu = GELU()
x = torch.tensor([-3.0, 0.0, 3.0])
y = gelu(x)
# Сравнение с PyTorch F.gelu (которая использует точный алгоритм)
y_ref = torch.nn.functional.gelu(x)
diff = (y - y_ref).abs().max().item()
assert diff < 5e-3, f"Max difference {diff} exceeds threshold"
def test_gelu_is_smooth_and_monotonic():
gelu = GELU()
x = torch.linspace(-5, 5, 100)
y = gelu(x)
dy = y[1:] - y[:-1]
# Проверяем, что функция GELU хотя бы локально монотонна на большинстве промежутков
assert (dy.mean() > 0 or dy.mean() < 0)
def test_gelu_gradients():
gelu = GELU()
x = torch.randn(3, 5, requires_grad=True)
y = gelu(x)
loss = y.sum()
loss.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_gelu_large_vs_small():
gelu = GELU()
x_pos = torch.tensor([100.0])
x_neg = torch.tensor([-100.0])
y_pos = gelu(x_pos)
y_neg = gelu(x_neg)
# Для больших положительных GELU(x) ~ x, для больших отрицательных ~0
assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4)
assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4)

View File

@@ -0,0 +1,67 @@
import torch
import pytest
from llm.core.gemma_decoder import GemmaDecoder
from llm.core.rope import RoPE
@pytest.fixture
def gemma_decoder():
rope = RoPE(head_size=4, max_seq_len=32)
return GemmaDecoder(
num_q_heads=4,
emb_size=16,
head_size=4,
max_seq_len=32,
rope=rope,
dropout=0.1,
)
def test_forward_shape(gemma_decoder):
x = torch.randn(2, 12, 16)
out, cache = gemma_decoder(x)
assert out.shape == (2, 12, 16)
assert isinstance(cache, tuple) or cache is None
def test_forward_masked(gemma_decoder):
x = torch.randn(1, 8, 16)
mask = torch.ones(1, 8, 8, dtype=torch.bool)
out, _ = gemma_decoder(x, mask=mask)
assert out.shape == x.shape
def test_forward_with_cache_flag(gemma_decoder):
x = torch.randn(2, 7, 16)
out, cache = gemma_decoder(x, use_cache=True, cache=None)
assert out.shape == (2, 7, 16)
def test_forward_wrong_seq_len_raises(gemma_decoder):
x = torch.randn(1, 100, 16)
with pytest.raises(Exception):
gemma_decoder(x)
def test_gradient_flow(gemma_decoder):
x = torch.randn(3, 9, 16, requires_grad=True)
y, _ = gemma_decoder(x)
y.sum().backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_various_shapes(gemma_decoder):
for b, s in [(1, 1), (2, 5), (2, 32)]:
x = torch.randn(b, s, 16)
y, _ = gemma_decoder(x)
assert y.shape == (b, s, 16)
def test_forward_repeatability():
torch.manual_seed(42)
rope = RoPE(head_size=4, max_seq_len=32)
decoder = GemmaDecoder(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
)
x = torch.randn(2, 8, 16)
y1, _ = decoder(x)
torch.manual_seed(42)
decoder2 = GemmaDecoder(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
)
x2 = torch.randn(2, 8, 16)
y2, _ = decoder2(x2)
assert torch.allclose(y1, y2, atol=1e-5)

View File

@@ -4,33 +4,43 @@ Tests for decoder block.
import pytest
import torch
from llm.core.decoder import Decoder
from llm.core.gpt_decoder import GptDecoder
class TestDecoder:
class TestGptDecoder:
"""Test cases for Decoder."""
def test_initialization(self, embed_dim, num_heads):
"""Test that Decoder can be initialized."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
assert decoder is not None
# Check internal components
assert hasattr(decoder, '_heads')
assert hasattr(decoder, '_ff')
assert hasattr(decoder, '_norm1')
assert hasattr(decoder, '_norm2')
assert hasattr(decoder, "_heads")
assert hasattr(decoder, "_ff")
assert hasattr(decoder, "_norm1")
assert hasattr(decoder, "_norm2")
def test_forward_pass(self, embed_dim, num_heads, random_embeddings):
"""Test forward pass of Decoder."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
# Forward pass
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Check output shape
assert output.shape == random_embeddings.shape
@@ -40,14 +50,19 @@ class TestDecoder:
"""Test forward pass with causal mask."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
batch_size, seq_len = random_embeddings.shape[:2]
# Create causal mask
mask = torch.tril(torch.ones(seq_len, seq_len))
# Forward pass with causal mask
output = decoder(random_embeddings, mask=mask)
output, _ = decoder(random_embeddings, attention_mask=mask)
# Check output shape
assert output.shape == random_embeddings.shape
@@ -56,9 +71,14 @@ class TestDecoder:
"""Test that residual connections are properly applied."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# With residual connections and layer norm, the output shouldn't be
# too different from input (in terms of scale/distribution)
@@ -72,9 +92,14 @@ class TestDecoder:
"""Test that layer normalization is applied."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Check that output has reasonable statistics (due to layer norm)
# Mean should be close to 0, std close to 1 for each sequence position
@@ -89,10 +114,15 @@ class TestDecoder:
"""Test that gradients flow through Decoder."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
# Forward pass
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Create a dummy loss and backward pass
loss = output.sum()
@@ -109,11 +139,16 @@ class TestDecoder:
"""Test that Decoder works on correct device."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len).to(device)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
).to(device)
inputs = random_embeddings.to(device)
# Forward pass
output = decoder(inputs)
output, _ = decoder(inputs)
# Check device consistency
assert output.device == device
@@ -122,7 +157,7 @@ class TestDecoder:
def test_different_configurations(self):
"""Test Decoder with different configurations."""
test_cases = [
(64, 2), # embed_dim=64, num_heads=2
(64, 2), # embed_dim=64, num_heads=2
(128, 4), # embed_dim=128, num_heads=4
(256, 8), # embed_dim=256, num_heads=8
]
@@ -130,11 +165,16 @@ class TestDecoder:
for embed_dim, num_heads in test_cases:
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
batch_size, seq_len = 2, 16
inputs = torch.randn(batch_size, seq_len, embed_dim)
output = decoder(inputs)
output, _ = decoder(inputs)
assert output.shape == inputs.shape
@@ -143,10 +183,15 @@ class TestDecoder:
"""Test Decoder with different input shapes."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
inputs = torch.randn(batch_size, seq_len, embed_dim)
output = decoder(inputs)
output, _ = decoder(inputs)
assert output.shape == (batch_size, seq_len, embed_dim)
@@ -154,15 +199,21 @@ class TestDecoder:
"""Test that Decoder behaves differently in train vs eval mode."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len, dropout=0.5)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
dropout=0.5,
)
# Training mode
decoder.train()
output_train = decoder(random_embeddings)
output_train, _ = decoder(random_embeddings)
# Evaluation mode
decoder.eval()
output_eval = decoder(random_embeddings)
output_eval, _ = decoder(random_embeddings)
# Outputs should be different due to dropout
assert not torch.allclose(output_train, output_eval)
@@ -171,18 +222,20 @@ class TestDecoder:
"""Test that parameters are properly initialized."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(num_heads=num_heads, emb_size=embed_dim, head_size=head_size, max_seq_len=max_seq_len)
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
# Check that various components have non-zero parameters
assert not torch.allclose(
decoder._heads._layer.weight,
torch.zeros_like(decoder._heads._layer.weight)
decoder._heads._layer.weight, torch.zeros_like(decoder._heads._layer.weight)
)
assert not torch.allclose(
decoder._ff._layer1.weight,
torch.zeros_like(decoder._ff._layer1.weight)
decoder._ff._layer1.weight, torch.zeros_like(decoder._ff._layer1.weight)
)
assert not torch.allclose(
decoder._norm1.weight,
torch.zeros_like(decoder._norm1.weight)
decoder._norm1.weight, torch.zeros_like(decoder._norm1.weight)
)

View File

@@ -0,0 +1,85 @@
# llm/tests/core/test_group_query_attention.py
import torch
import pytest
from llm.core.group_query_attention import GroupedQueryAttention
from llm.core.rope import RoPE
@pytest.fixture
def params():
return {
'num_q_heads': 4,
'num_kv_heads': 2,
'emb_size': 16,
'head_size': 4,
'max_seq_len': 32,
'window_size': 8,
'dropout': 0.0
}
def test_initialization(params):
attn = GroupedQueryAttention(**params)
assert isinstance(attn, GroupedQueryAttention)
def test_forward_shape(params):
batch, seq = 2, 10
x = torch.randn(batch, seq, params['emb_size'])
attn = GroupedQueryAttention(**params)
y, cache = attn(x)
assert y.shape == (batch, seq, params['emb_size'])
assert cache is not None
assert isinstance(y, torch.Tensor)
def test_forward_shape_with_mask(params):
batch, seq = 2, 10
x = torch.randn(batch, seq, params['emb_size'])
mask = torch.tril(torch.ones(seq, seq)).bool()
attn = GroupedQueryAttention(**params)
y, _ = attn(x, mask=mask)
assert y.shape == (batch, seq, params['emb_size'])
def test_kv_repetition(params):
batch, seq = 1, 3
attn = GroupedQueryAttention(**params)
kv = torch.randn(batch, params['num_kv_heads'], seq, params['head_size'])
rep = attn._repeat_kv_heads(kv, params['num_q_heads'], params['num_kv_heads'])
assert rep.shape == (batch, params['num_q_heads'], seq, params['head_size'])
def test_window_mask(params):
attn = GroupedQueryAttention(**params)
mask = attn._create_sliding_window_mask(8, 3)
assert mask.shape == (8, 8)
# Проверим булеву маску окна в позиции 4
expected = torch.tensor([True, True, True, True, False, False])
assert torch.equal(mask[4, 1:7], expected)
def test_forward_with_rope(params):
batch, seq = 2, 12
x = torch.randn(batch, seq, params['emb_size'])
rope = RoPE(head_size=params['head_size'], max_seq_len=params['max_seq_len'])
params2 = params.copy()
params2['rope'] = rope
attn = GroupedQueryAttention(**params2)
y, _ = attn(x)
assert y.shape == (batch, seq, params['emb_size'])
def test_cache_usage(params):
batch, seq = 1, 5
x = torch.randn(batch, seq, params['emb_size'])
attn = GroupedQueryAttention(**params)
# Первый проход - получаем кэш
_, cache = attn(x)
# Второй проход с кэшем (имитируем автокомплит seq_len=1)
x2 = torch.randn(batch, 1, params['emb_size'])
y2, cache2 = attn(x2, cache=cache)
assert cache2 is not None
assert y2.shape == (batch, 1, params['emb_size'])
def test_gradient_backward(params):
batch, seq = 2, 6
x = torch.randn(batch, seq, params['emb_size'], requires_grad=True)
attn = GroupedQueryAttention(**params)
y, _ = attn(x)
y.sum().backward()
for param in attn.parameters():
assert param.grad is not None

View File

@@ -0,0 +1,66 @@
import torch
import pytest
from llm.core.mistral_decoder import MistralDecoder
from llm.core.rope import RoPE
@pytest.fixture
def decoder_config():
# Current MistralDecoder is a single block (not a stack).
return dict(
num_q_heads=4,
num_kv_heads=2,
emb_size=32,
head_size=8,
max_seq_len=128,
window_size=16,
rope=RoPE(head_size=8, max_seq_len=128),
dropout=0.0
)
def test_mistral_decoder_init(decoder_config):
model = MistralDecoder(**decoder_config)
assert model is not None
def test_mistral_decoder_forward_shapes(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
output, cache = model(x, use_cache=True)
assert output.shape == (batch, seq_len, emb_size)
assert cache is not None
def test_mistral_decoder_forward_no_cache(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
output, cache = model(x, use_cache=False)
assert output.shape == (batch, seq_len, emb_size)
assert cache is None
def test_mistral_decoder_cache_shapes(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 8, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
# Первый проход — без кэша
_, cache = model(x, use_cache=True)
# Второй проход — заполняем кэш
x_next = torch.randn(batch, 1, emb_size)
_, cache2 = model(x_next, use_cache=True, cache=cache)
# Можно проверить, что кэш не None и корректной структуры:
assert cache2 is not None
def test_mistral_decoder_shape_error(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, decoder_config['max_seq_len'] + 1, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size)
with pytest.raises(ValueError):
model(x)
def test_mistral_decoder_backward(decoder_config):
model = MistralDecoder(**decoder_config)
batch, seq_len, emb_size = 2, 10, decoder_config['emb_size']
x = torch.randn(batch, seq_len, emb_size, requires_grad=True)
output, _ = model(x, use_cache=False)
loss = output.sum()
loss.backward()
assert x.grad is not None

View File

@@ -0,0 +1,80 @@
import torch
import pytest
from llm.core.mixtral_decoder import MixtralDecoder
from llm.core.rope import RoPE
@pytest.fixture
def basic_decoder():
emb_size = 16
num_q_heads = 4
num_kv_heads = 2
head_size = 4
max_seq_len = 32
num_experts = 4
top_k_experts = 2
window_size = 8
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
return MixtralDecoder(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
num_experts=num_experts,
top_k_experts=top_k_experts,
window_size=window_size,
rope=rope,
dropout=0.0,
)
def test_forward_shape(basic_decoder):
x = torch.randn(2, 10, 16)
out, cache = basic_decoder(x)
assert out.shape == (2, 10, 16)
assert cache is None or isinstance(cache, (tuple, list))
def test_forward_masked(basic_decoder):
x = torch.randn(3, 7, 16)
mask = torch.ones(3, 7, 7, dtype=torch.bool)
out, cache = basic_decoder(x, mask=mask)
assert out.shape == (3, 7, 16)
def test_forward_with_cache_flag(basic_decoder):
x = torch.randn(2, 8, 16)
out, cache = basic_decoder(x, use_cache=True, cache=None)
assert out.shape == (2, 8, 16)
assert isinstance(cache, (tuple, list)) or cache is None
def test_backprop_pass(basic_decoder):
x = torch.randn(2, 5, 16, requires_grad=True)
out, _ = basic_decoder(x)
y = out.sum()
y.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_seq_too_long_raises(basic_decoder):
x = torch.randn(1, 40, 16) # seq_len > max_seq_len
with pytest.raises(Exception):
basic_decoder(x)
def test_different_config():
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
)
x = torch.randn(1, 8, 4)
out, cache = decoder(x)
assert out.shape == x.shape
def test_forward_no_dropout():
# Проверка на корректность shape при отсутствии Dropout
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
)
x = torch.randn(2, 3, 4)
out, cache = decoder(x)
assert out.shape == x.shape

View File

@@ -0,0 +1,61 @@
import torch
import pytest
from llm.core.moe import MoE
@pytest.fixture
def moe():
# Базовая MoE для коротких тестов
return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)
def test_forward_shape(moe):
x = torch.randn(3, 5, 16) # [batch, seq, emb]
y = moe(x)
assert y.shape == x.shape
def test_forward_grad(moe):
x = torch.randn(2, 4, 16, requires_grad=True)
y = moe(x)
(y.sum()).backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_top_k_larger_than_experts():
# top_k_experts > num_experts должно падать
with pytest.raises(ValueError):
MoE(emb_size=8, num_experts=2, top_k_experts=4)
def test_single_expert_no_error():
# один эксперт, один топ-к — модель всё ещё валидна
moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
x = torch.randn(2, 2, 8)
y = moe(x)
assert y.shape == x.shape
def test_forward_trivial_weights():
"""Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам."""
class DummyMoE(MoE):
def forward(self, x):
# Роутер отдаёт всегда единичные логиты = softmax -> uniform
self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
torch.nn.init.constant_(self._router.weight, 0.0)
return super().forward(x)
moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
x = torch.zeros(1, 2, 4)
y = moe(x)
assert y.shape == x.shape
def test_forward_deterministic_seed(moe):
torch.manual_seed(42)
x = torch.randn(2, 3, 16)
y1 = moe(x)
torch.manual_seed(42)
y2 = moe(x)
assert torch.allclose(y1, y2, atol=1e-5)
def test_forward_no_dropout():
"""Без dropout MoE не меняет shape и не даёт NaN."""
moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
x = torch.randn(2, 7, 5)
y = moe(x)
assert y.shape == x.shape
assert not torch.isnan(y).any()

View File

@@ -13,18 +13,22 @@ class TestMultiHeadAttention:
def test_initialization(self, embed_dim, num_heads):
"""Test that MultiHeadAttention can be initialized."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
assert attention is not None
# Check internal attributes
assert len(attention._heads) == num_heads
assert attention._num_heads == num_heads
assert attention._layer.in_features == embed_dim
assert attention._layer.out_features == embed_dim
def test_forward_pass(self, embed_dim, num_heads, random_embeddings):
"""Test forward pass of MultiHeadAttention."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
# Forward pass
output, _ = attention(random_embeddings)
@@ -36,7 +40,9 @@ class TestMultiHeadAttention:
def test_forward_with_mask(self, embed_dim, num_heads, random_embeddings):
"""Test forward pass with attention mask."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
# Create a simple mask
seq_len = random_embeddings.shape[1]
@@ -51,7 +57,9 @@ class TestMultiHeadAttention:
def test_causal_mask(self, embed_dim, num_heads, random_embeddings):
"""Test that causal mask prevents attending to future positions."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
# Create causal mask
seq_len = random_embeddings.shape[1]
@@ -63,10 +71,14 @@ class TestMultiHeadAttention:
# Check output shape
assert output.shape == random_embeddings.shape
def test_attention_weights_normalization(self, embed_dim, num_heads, random_embeddings):
def test_attention_weights_normalization(
self, embed_dim, num_heads, random_embeddings
):
"""Test that attention weights are properly normalized."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
# Forward pass
output, _ = attention(random_embeddings)
@@ -77,7 +89,9 @@ class TestMultiHeadAttention:
def test_gradient_flow(self, embed_dim, num_heads, random_embeddings):
"""Test that gradients flow through MultiHeadAttention."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
# Forward pass
output, _ = attention(random_embeddings)
@@ -88,13 +102,17 @@ class TestMultiHeadAttention:
# Check that gradients are computed for learnable parameters
assert attention._layer.weight.grad is not None
if len(attention._heads) > 0:
assert attention._heads[0]._q.weight.grad is not None
# Проверяем, что также у градиентов весов q/k/v есть значения
assert attention._q.weight.grad is not None
assert attention._k.weight.grad is not None
assert attention._v.weight.grad is not None
def test_device_consistency(self, embed_dim, num_heads, random_embeddings, device):
"""Test that MultiHeadAttention works on correct device."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024).to(device)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
).to(device)
inputs = random_embeddings.to(device)
# Forward pass
@@ -107,15 +125,17 @@ class TestMultiHeadAttention:
def test_different_embed_dim_and_heads(self):
"""Test MultiHeadAttention with different embed_dim and num_heads combinations."""
test_cases = [
(64, 2), # embed_dim=64, num_heads=2
(64, 2), # embed_dim=64, num_heads=2
(128, 4), # embed_dim=128, num_heads=4
(256, 8), # embed_dim=256, num_heads=8
(512, 16), # embed_dim=512, num_heads=16
(512, 16), # embed_dim=512, num_heads=16
]
for embed_dim, num_heads in test_cases:
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
batch_size, seq_len = 2, 16
inputs = torch.randn(batch_size, seq_len, embed_dim)
@@ -126,7 +146,9 @@ class TestMultiHeadAttention:
def test_attention_output_range(self, embed_dim, num_heads, random_embeddings):
"""Test that attention output is in reasonable range."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
output, _ = attention(random_embeddings)
@@ -137,7 +159,9 @@ class TestMultiHeadAttention:
def test_different_input_shapes(self, embed_dim, num_heads, batch_size, seq_len):
"""Test MultiHeadAttention with different input shapes."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024)
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024
)
inputs = torch.randn(batch_size, seq_len, embed_dim)
output, _ = attention(inputs)
@@ -147,7 +171,9 @@ class TestMultiHeadAttention:
def test_parameter_sharing(self, embed_dim, num_heads):
"""Test that parameters are properly shared across the sequence."""
head_size = embed_dim // num_heads
attention = MultiHeadAttention(num_heads, embed_dim, head_size, max_seq_len=1024, dropout=0.0) # No dropout for deterministic test
attention = MultiHeadAttention(
num_heads, embed_dim, head_size, max_seq_len=1024, dropout=0.0
) # No dropout for deterministic test
# Create two identical sequences
seq_len = 10

View File

@@ -0,0 +1,71 @@
import torch
import pytest
from llm.core.multi_query_attention import MultiQueryAttention
from llm.core.rope import RoPE
@pytest.fixture
def mqa_rope():
return MultiQueryAttention(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=RoPE(head_size=4, max_seq_len=32), dropout=0.1
)
@pytest.fixture
def mqa_no_rope():
return MultiQueryAttention(
num_q_heads=2, emb_size=8, head_size=4, max_seq_len=16, rope=None, dropout=0.0
)
def test_forward_shape(mqa_rope):
x = torch.randn(2, 10, 16)
out, cache = mqa_rope(x)
assert out.shape == (2, 10, 16)
assert isinstance(cache, tuple) and len(cache) == 2
def test_forward_masked(mqa_rope):
x = torch.randn(2, 8, 16)
mask = torch.ones(2, 8, 8, dtype=torch.bool)
out, cache = mqa_rope(x, mask=mask)
assert out.shape == (2, 8, 16)
def test_forward_cache(mqa_rope):
x = torch.randn(1, 4, 16)
# Первый вызов — кэша нет
out1, cache1 = mqa_rope(x)
# Повторяем: подаем x второй раз — теперь добавим cache
out2, cache2 = mqa_rope(x, use_cache=True, cache=cache1)
assert out2.shape == (1, 4, 16)
assert isinstance(cache2, tuple) and len(cache2) == 2
# Проверка, что длина k_cache увеличилась
assert cache2[0].shape[2] == cache1[0].shape[2] + x.shape[1] # по длине seq
def test_forward_no_rope(mqa_no_rope):
x = torch.randn(3, 6, 8)
out, _ = mqa_no_rope(x)
assert out.shape == (3, 6, 8)
def test_forward_different_batch_seq(mqa_rope):
for batch, seq in [(1, 1), (2, 5), (3, 32)]:
x = torch.randn(batch, seq, 16)
out, _ = mqa_rope(x)
assert out.shape == (batch, seq, 16)
def test_forward_raise_on_long_seq(mqa_rope):
x = torch.randn(2, 40, 16) # seq_len > max_seq_len
with pytest.raises(ValueError):
mqa_rope(x)
def test_forward_grad(mqa_rope):
x = torch.randn(2, 7, 16, requires_grad=True)
out, _ = mqa_rope(x)
y = out.sum()
y.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_dropout_applied():
mqa = MultiQueryAttention(num_q_heads=2, emb_size=8, head_size=4, max_seq_len=12, rope=None, dropout=0.99)
x = torch.ones(1, 3, 8)
mqa.train()
y, _ = mqa(x)
# При очень большом dropout почти всё обнуляется
assert (torch.abs(y) < 1e-5).float().mean() > 0.6 or y.sum() < 1e-2

View File

@@ -18,7 +18,7 @@ class TestPositionalEmbeddings:
assert embeddings is not None
# Check that positional embeddings are created
assert hasattr(embeddings, 'embedding')
assert hasattr(embeddings, "embedding")
assert embeddings.embedding.weight.shape == (max_seq_len, embed_dim)
def test_forward_pass(self, embed_dim):
@@ -52,7 +52,7 @@ class TestPositionalEmbeddings:
def test_different_sequence_lengths(self, embed_dim):
"""Test PositionalEmbeddings with different sequence lengths."""
test_cases = [
(10, 5), # seq_len < max_seq_len
(10, 5), # seq_len < max_seq_len
(10, 10), # seq_len == max_seq_len
]
@@ -80,8 +80,10 @@ class TestPositionalEmbeddings:
# Positional embeddings should have gradients (they're learnable)
assert embeddings.embedding.weight.grad is not None
assert not torch.allclose(embeddings.embedding.weight.grad,
torch.zeros_like(embeddings.embedding.weight.grad))
assert not torch.allclose(
embeddings.embedding.weight.grad,
torch.zeros_like(embeddings.embedding.weight.grad),
)
def test_device_consistency(self, embed_dim, device):
"""Test that PositionalEmbeddings works on correct device."""
@@ -103,7 +105,9 @@ class TestPositionalEmbeddings:
embeddings2 = PositionalEmbeddings(max_seq_len, embed_dim)
# Different instances should have different embeddings (random initialization)
assert not torch.allclose(embeddings1.embedding.weight, embeddings2.embedding.weight)
assert not torch.allclose(
embeddings1.embedding.weight, embeddings2.embedding.weight
)
# But same instance should produce same output for same input
seq_len = 50
@@ -122,11 +126,14 @@ class TestPositionalEmbeddings:
assert not torch.allclose(pe[0], pe[1], rtol=1e-4)
assert not torch.allclose(pe[10], pe[20], rtol=1e-4)
@pytest.mark.parametrize("max_seq_len,seq_len,embed_dim", [
(64, 10, 64),
(128, 50, 128),
(256, 100, 256),
])
@pytest.mark.parametrize(
"max_seq_len,seq_len,embed_dim",
[
(64, 10, 64),
(128, 50, 128),
(256, 100, 256),
],
)
def test_different_configurations(self, max_seq_len, seq_len, embed_dim):
"""Test PositionalEmbeddings with different configurations."""
embeddings = PositionalEmbeddings(max_seq_len, embed_dim)

View File

@@ -0,0 +1,47 @@
import torch
import pytest
from llm.core.rms_norm import RMSNorm
def test_rmsnorm_shape_preservation():
norm = RMSNorm(64)
x = torch.randn(3, 5, 64)
y = norm(x)
assert y.shape == x.shape
def test_rmsnorm_dtype_and_device():
norm = RMSNorm(32)
x = torch.randn(8, 32, device='cpu', dtype=torch.float64)
y = norm(x)
assert y.dtype == torch.float64
assert y.device == x.device
def test_rmsnorm_mean_no_shift():
norm = RMSNorm(32)
x = torch.randn(3, 128, 32)
y = norm(x)
rms = torch.sqrt((y ** 2).mean(dim=-1))
w_mean = norm._w.mean().item()
assert torch.allclose(rms.mean(), torch.tensor(w_mean), rtol=0.2, atol=0.2)
def test_rmsnorm_backward():
norm = RMSNorm(16)
x = torch.randn(2, 15, 16, requires_grad=True)
y = norm(x)
loss = y.sum()
loss.backward()
assert x.grad is not None
assert norm._w.grad is not None
def test_rmsnorm_fp16():
norm = RMSNorm(8).half()
x = torch.randn(2, 6, 8).half()
y = norm(x)
assert y.shape == x.shape
assert y.dtype == torch.float16
def test_rmsnorm_large_eps_stability():
norm = RMSNorm(16, eps=1)
x = torch.zeros(2, 5, 16)
y = norm(x)
assert not torch.isnan(y).any()
assert not torch.isinf(y).any()

View File

@@ -0,0 +1,55 @@
import torch
import pytest
from llm.core.rope import RoPE
def test_rope_shapes_and_dtype():
rope = RoPE(head_size=8, max_seq_len=32)
x = torch.randn(2, 4, 16, 8) # [batch, num_heads, seq_len, head_size]
y = rope(x)
assert y.shape == x.shape
assert y.dtype == x.dtype
def test_rope_raises_on_bad_ndim():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 16, 8) # [batch, seq_len, head_size] (3D)
with pytest.raises(AssertionError):
_ = rope(x)
def test_rope_preserves_norm():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 3, 7, 8)
x_norm = x.norm(dim=-1)
y = rope(x)
y_norm = y.norm(dim=-1)
# Нормы могут немного отличаться из-за float, сравниваем с допуском
assert torch.allclose(x_norm, y_norm, rtol=1e-5, atol=1e-7)
def test_rope_backward_pass():
rope = RoPE(head_size=8, max_seq_len=16)
x = torch.randn(2, 2, 8, 8, requires_grad=True)
out = rope(x)
loss = out.sum()
loss.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
@pytest.mark.parametrize("batch,num_heads,seq_len,head_size", [
(1, 1, 4, 8),
(2, 4, 16, 8),
(3, 2, 7, 8),
])
def test_rope_various_shapes(batch, num_heads, seq_len, head_size):
rope = RoPE(head_size=head_size, max_seq_len=32)
x = torch.randn(batch, num_heads, seq_len, head_size)
y = rope(x)
assert y.shape == x.shape
def test_rope_start_pos():
rope = RoPE(head_size=8, max_seq_len=32)
x_full = torch.randn(1, 2, 8, 8)
# Сравниваем участок результата для разных start_pos
out1 = rope(x_full)
out2 = rope(x_full, start_pos=2)
assert not torch.allclose(out1, out2)
# Для одинакового start_pos и x должны совпадать
assert torch.allclose(rope(x_full, start_pos=1), rope(x_full, start_pos=1))

View File

@@ -0,0 +1,42 @@
import torch
import pytest
from llm.core.silu import SiLU
def test_silu_shape_and_dtype():
silu = SiLU()
x = torch.randn(3, 10, 8)
y = silu(x)
assert y.shape == x.shape
assert y.dtype == x.dtype
def test_silu_known_values():
silu = SiLU()
x = torch.tensor([-2.0, 0.0, 2.0])
y = silu(x)
# PyTorch эталон
y_ref = torch.nn.functional.silu(x)
assert torch.allclose(y, y_ref, atol=1e-6)
def test_silu_large_vs_small():
silu = SiLU()
x_pos = torch.tensor([100.0])
x_neg = torch.tensor([-100.0])
y_pos = silu(x_pos)
y_neg = silu(x_neg)
assert torch.allclose(y_pos, x_pos, rtol=1e-4, atol=1e-4) # SiLU(x) ~ x для больших x>0
assert torch.allclose(y_neg, torch.zeros_like(x_neg), rtol=1e-4, atol=1e-4) # SiLU(x) ~ 0 для x<0
def test_silu_gradients():
silu = SiLU()
x = torch.randn(4, 4, requires_grad=True)
y = silu(x)
loss = y.sum()
loss.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_silu_broadcast():
silu = SiLU()
x = torch.randn(3, 1, 16)
y = silu(x)
assert y.shape == x.shape

View File

@@ -0,0 +1,39 @@
import torch
import pytest
from llm.core.swi_glu import SwiGLU
def test_swiglu_shape_and_dtype():
swiglu = SwiGLU(emb_size=32, dropout=0.1)
x = torch.randn(4, 10, 32)
y = swiglu(x)
assert y.shape == x.shape
assert y.dtype == x.dtype
def test_swiglu_forward_range():
swiglu = SwiGLU(emb_size=16, dropout=0.0)
x = torch.randn(3, 7, 16)
y = swiglu(x)
assert y.abs().max() < 20
def test_swiglu_gradients():
swiglu = SwiGLU(emb_size=8, dropout=0.0)
x = torch.randn(2, 5, 8, requires_grad=True)
out = swiglu(x)
loss = out.pow(2).sum()
loss.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_swiglu_fp16():
swiglu = SwiGLU(emb_size=16, dropout=0.0).half()
x = torch.randn(1, 8, 16).half()
y = swiglu(x)
assert y.shape == x.shape
assert y.dtype == torch.float16
def test_swiglu_reproducibility():
swiglu = SwiGLU(emb_size=8, dropout=0.0)
x = torch.ones(2, 4, 8)
y1 = swiglu(x)
y2 = swiglu(x)
assert torch.allclose(y1, y2)

Some files were not shown because too many files have changed in this diff Show More