12 Commits

Author SHA1 Message Date
Sergey Penkovsky
7744658716 Merge pull request #6 from pese-git/ref/gpt1
Ref/gpt1
2025-10-31 09:15:54 +03:00
Sergey Penkovsky
21cfd79c19 refactor(assets): update and reorganize GPT-1 architecture diagrams
- Renamed GPT-1 main scheme files for clarity
- Added new diagram files for attention, decoder, embeddings, and forward blocks (both .drawio and .png)
- Removed deprecated files (gpt11.drawio, gpt1.svg)
- Updated notebooks/gpt.ipynb with relevant changes
2025-10-30 14:40:31 +03:00
Sergey Penkovsky
9e2796e6be docs(gpt1): add architecture diagrams and notebook updates
- Added architecture diagrams for GPT-1: gpt1.drawio, gpt11.drawio (drawio format)
- Exported visualization images: gpt1.png, gpt1.svg for documentation and presentations
- Updated gpt.ipynb notebook to reference new materials and possibly add explanations of layers/logic
- New assets help to clarify model structure and training flow for both contributors and external users
2025-10-24 17:42:11 +03:00
Sergey Penkovsky
25caf69ced refactor(gpt1): migrate Decoder to GptDecoder, unify API, and update tests
- Renamed Decoder (and decoder.py) to GptDecoder (gpt_decoder.py) for clarity in GPT1
- Implemented support for cache and use_cache parameters in GptDecoder.forward (API unification)
- Adapted all usages in GPT model to use new decoder structure and handle tuple output
- Refactored core tests (test_gpt.py, test_gpt_decoder.py, test_basic.py) to correctly expect tuple or logits and ensure shape/device checks work as before
- Improved clarity and future extensibility for autoregressive generation and benchmarking
- No changes to architectural details or training loop; pure API and test modernization
2025-10-22 16:27:08 +03:00
Sergey Penkovsky
ddc4924a37 refactor(models): unify generate() signatures across all LLM architectures\n\n- Unified method signature: (x, max_new_tokens, do_sample, temperature, top_k, top_p, use_cache, attention_mask, **kwargs)\n- Added del attention_mask, kwargs in every generate() for compatibility and clean API\n- Prepared for drop-in replacement and ease of future batching/serving\n\nNo changes to core model logic or sampling algorithms. 2025-10-22 11:57:26 +03:00
Sergey Penkovsky
92a34551b8 Merge pull request #5 from pese-git/feature/gemma
Feature/gemma
2025-10-21 17:53:55 +03:00
Sergey Penkovsky
ea932a36f3 feat(gemma): document and test GeGLU, MultiQueryAttention, GemmaDecoder, update Gemma model docs
- Add new core modules: GeGLU (Gated GELU Linear Unit), GemmaDecoder, MultiQueryAttention; all with highly detailed scientific (RU) docstrings: theory, usage, formulas, references
- Major doc improvements in Gemma model: class, __init__, forward, generate now have full educational/engineering docstrings, use-case samples, and literature links
- Add comprehensive unit tests:
    * tests/core/test_geglu.py: GeGLU coverage (shape, grads, edge, repeat, float16/skip)
    * tests/core/test_gemma_decoder.py: GemmaDecoder coverage (shape, mask, cache, repeatability, errors)
    * tests/core/test_multi_query_attention.py: MQA coverage (shape, cache, gradients, masking, dropout, raise)
- All modules and tests follow strict quality/documentation standards, code is now robust for research & production
2025-10-21 15:12:45 +03:00
Sergey Penkovsky
cfb4b6dfb1 feat(gemma): initial implementation of Gemma model and configs
- Add core Gemma model (architecture, attention, GeGLU, RoPE, RMSNorm, etc)
- Add configs for training and generation: gemma_train.json, gemma_generate.json
- Add Gemma notebook for exploratory analysis and demonstration
- Add __init__.py for Gemma submodule
- Update run_llm_experiment.py to support Gemma experiment configs

test(gemma): add comprehensive unit tests for Gemma

- Test forward pass (with/without cache)
- Test autoregressive generation (greedy, top-k, top-p)
- Test shape correctness and max sequence length errors
- Test multi-layer stack and token embeddings

docs: add documentation notebook for Gemma usage and analysis

Closes: #issue (if applicable)
2025-10-21 01:02:15 +03:00
Sergey Penkovsky
58c4a00b48 Merge pull request #4 from pese-git/feature/mixtral
Feature/mixtral
2025-10-20 16:36:39 +03:00
Sergey Penkovsky
c9da4c841b feat(mixtral): add MixtralDecoder, enhance MoE and Mixtral model docs, add unit tests
- Implement new core module: MixtralDecoder (llm/core/mixtral_decoder.py) with full Russian scientific docstrings, formal math, and usage examples
- Improve MoE: add Russian docstrings for class, __init__, forward; validate top_k_experts; explain theory and components
- Refactor Mixtral model: switch stack to MixtralDecoder, add comprehensive documentation for class, constructor and forward, clarify config usage and architecture
- Add thorough unit tests:
   * tests/core/test_mixtral_decoder.py: checks shapes, errors, mask, dropout, grads etc.
   * tests/core/test_moe.py: covers normal and edge-case logic, gradients, shape, params check
- All code and tests in compliance with recent scientific and engineering standards.
2025-10-20 16:07:51 +03:00
Sergey Penkovsky
b1737bbce2 feat(mixtral): initial implementation of Mixtral MoE model, configs, and tests
- Add Mixtral architecture implementation with MoE support (llm/src/llm/models/mixtral/mixtral.py)
- Introduce generic Mixture-of-Experts (MoE) block (llm/src/llm/core/moe.py)
- Create dedicated configuration files for Mixtral training and generation experiments
- Register and test Mixtral support in experiment runner (run_llm_experiment.py)
- Add unit tests for Mixtral API including forward, caching, and generation modes
- Include Jupyter notebook mixstral.ipynb for architectural exploration and research
- Ensure correct handling of torch bool masks in sampling (top-k, top-p) during generation

BREAKING CHANGE: Adds new model code and test coverage, modifying experiment runner logic to register Mixtral.
2025-10-20 08:12:11 +03:00
Sergey Penkovsky
1aba02cab9 Merge pull request #3 from pese-git/feature/mistral
Feature/mistral
2025-10-17 20:45:20 +03:00
42 changed files with 6376 additions and 87 deletions

View File

@@ -0,0 +1,148 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="1216" dy="316" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="92" value="" style="group" vertex="1" connectable="0" parent="1">
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="4" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="350" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="5" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="236.67904761904765" y="125"/>
<mxPoint x="292.38095238095235" y="125"/>
<mxPoint x="355" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="34.25858250276859" y="130"/>
<mxPoint x="89.96048726467328" y="130"/>
<mxPoint x="155" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="44" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -0,0 +1,413 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="2176" dy="702" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="92" value="" style="group" parent="1" vertex="1" connectable="0">
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;container=0;" parent="92" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="44" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="4" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="281.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="5" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="580" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="5" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="490.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="609.9976190476191" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="12" target="5" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="407.1428571428571" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="3" target="4" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="250" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="459.5238095238095" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="18" target="12" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="385.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="459.5238095238095" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="380.00428571428574" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="4" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="443.80952380952385" y="410" as="sourcePoint"/>
<mxPoint x="375.7142857142858" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" target="24" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="466.8571428571429" y="80" as="sourcePoint"/>
<mxPoint x="579.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="466.67904761904765" y="125"/>
<mxPoint x="522.3809523809523" y="125"/>
<mxPoint x="585" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="24" target="7" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="580.0019047619048" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="264.3255813953488" y="80" as="sourcePoint"/>
<mxPoint x="380.7142857142858" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="264.2585825027686" y="130"/>
<mxPoint x="319.96048726467325" y="130"/>
<mxPoint x="385" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="141" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="92">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="180" y="250" as="sourcePoint"/>
<mxPoint x="281" y="110" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="142" value="" style="endArrow=none;dashed=1;html=1;entryX=1;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="1" target="4">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="620" y="560" as="sourcePoint"/>
<mxPoint x="660" y="520" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="218" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="1">
<mxGeometry x="130" y="660" width="680" height="160" as="geometry"/>
</mxCell>
<mxCell id="195" style="edgeStyle=orthogonalEdgeStyle;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="147">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="196" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="148">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="197" style="edgeStyle=orthogonalEdgeStyle;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="143" target="149">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="143" value="X" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fff2cc;strokeColor=#d6b656;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry y="60" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="147" value="W&lt;sub&gt;k&lt;/sub&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="80" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="199" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="148" target="151">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="148" value="W&lt;sub&gt;q&lt;/sub&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="80" y="60" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="149" value="W&lt;sub&gt;v&lt;/sub&gt;" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="80" y="120" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="207" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="150" target="158">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="236" y="20"/>
<mxPoint x="236" y="50"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="150" value="K" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="158.97000000000003" width="41.03" height="40" as="geometry"/>
</mxCell>
<mxCell id="208" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="151" target="190">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="151" value="Q" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="158.97000000000003" y="60" width="41.03" height="40" as="geometry"/>
</mxCell>
<mxCell id="214" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="152" target="187">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="600" y="140"/>
<mxPoint x="600" y="80"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="152" value="V" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="158.97000000000003" y="120" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="215" style="edgeStyle=none;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="187">
<mxGeometry relative="1" as="geometry">
<mxPoint x="680" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="187" value="O" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#fff2cc;strokeColor=#d6b656;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="218">
<mxGeometry x="620" y="60" width="40" height="40" as="geometry"/>
</mxCell>
<mxCell id="211" style="edgeStyle=none;html=1;entryX=0;entryY=0;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="188" target="179">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="188" value="Scale" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;direction=east;rotation=90;fillColor=#f8cecc;strokeColor=#b85450;" vertex="1" parent="218">
<mxGeometry x="370" y="37.5" width="50" height="25" as="geometry"/>
</mxCell>
<mxCell id="213" style="edgeStyle=orthogonalEdgeStyle;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="189" target="187">
<mxGeometry relative="1" as="geometry">
<Array as="points">
<mxPoint x="600" y="50"/>
<mxPoint x="600" y="80"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="189" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;direction=east;rotation=90;fillColor=#e1d5e7;strokeColor=#9673a6;" vertex="1" parent="218">
<mxGeometry x="530" y="37.5" width="50" height="25" as="geometry"/>
</mxCell>
<mxCell id="209" style="edgeStyle=none;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="190" target="188">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="190" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="218">
<mxGeometry x="272.5" y="10" width="80" height="80" as="geometry"/>
</mxCell>
<mxCell id="153" value="" style="whiteSpace=wrap;html=1;aspect=fixed;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry width="80" height="80" as="geometry"/>
</mxCell>
<mxCell id="154" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="155" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="156" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="157" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="158" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="159" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="20" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="160" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="40" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="161" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="60" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="162" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="163" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="20" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="164" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="40" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="165" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="60" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="166" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="167" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="20" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="168" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="40" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="169" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="190">
<mxGeometry x="60" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="191" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="218">
<mxGeometry x="440" y="10" width="80" height="80" as="geometry"/>
</mxCell>
<mxCell id="170" value="" style="whiteSpace=wrap;html=1;aspect=fixed;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry width="80" height="80" as="geometry"/>
</mxCell>
<mxCell id="171" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="172" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="173" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="174" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="175" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="176" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="20" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="177" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="40" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="178" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="60" y="20" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="179" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="180" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="20" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="181" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="40" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="182" value="" style="rounded=0;whiteSpace=wrap;html=1;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="60" y="40" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="183" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="184" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="20" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="185" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="40" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="186" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#dae8fc;strokeColor=#6c8ebf;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" parent="191">
<mxGeometry x="60" y="60" width="20" height="20" as="geometry"/>
</mxCell>
<mxCell id="198" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="147" target="150">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="200" style="edgeStyle=none;html=1;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" target="152">
<mxGeometry relative="1" as="geometry">
<mxPoint x="120" y="140" as="sourcePoint"/>
<mxPoint x="148.97000000000008" y="139.8599999999998" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="212" value="" style="endArrow=classic;html=1;exitX=1;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" edge="1" parent="218" source="182" target="189">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="610" y="50" as="sourcePoint"/>
<mxPoint x="660" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="219" value="" style="group;movable=0;resizable=0;rotatable=0;deletable=0;editable=0;locked=1;connectable=0;" vertex="1" connectable="0" parent="1">
<mxGeometry x="289.99776556776555" y="520" width="250.00223443223445" height="90" as="geometry"/>
</mxCell>
<mxCell id="145" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" source="133" target="144">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="133" value="Concat" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;direction=east;rotation=90;" vertex="1" parent="219">
<mxGeometry x="132.50223443223445" y="32.5" width="50" height="25" as="geometry"/>
</mxCell>
<mxCell id="136" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" target="133">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="108.97435897435912" y="45" as="sourcePoint"/>
<mxPoint x="250.00223443223445" y="25" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="129" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
<mxGeometry x="30" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="130" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
<mxGeometry x="20" y="10" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="131" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
<mxGeometry x="10" y="20" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="132" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" vertex="1" parent="219">
<mxGeometry y="30" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="146" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" edge="1" parent="219" source="144">
<mxGeometry relative="1" as="geometry">
<mxPoint x="250.00223443223445" y="44.969696969697" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="144" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;direction=east;rotation=90;" vertex="1" parent="219">
<mxGeometry x="182.50223443223445" y="32.5" width="50" height="25" as="geometry"/>
</mxCell>
<mxCell id="220" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="90" y="690" as="sourcePoint"/>
<mxPoint x="290" y="610" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="221" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="370" y="610" as="sourcePoint"/>
<mxPoint x="830" y="700" as="targetPoint"/>
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -0,0 +1,148 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="979" dy="301" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="92" value="" style="group" parent="1" vertex="1" connectable="0">
<mxGeometry x="40" y="320" width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;strokeColor=#FF3333;" parent="92" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="4" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="350" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="5" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="236.67904761904765" y="125"/>
<mxPoint x="292.38095238095235" y="125"/>
<mxPoint x="355" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="34.25858250276859" y="130"/>
<mxPoint x="89.96048726467328" y="130"/>
<mxPoint x="155" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="44" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -0,0 +1,148 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="1216" dy="316" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="91" value="" style="group;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="1" vertex="1" connectable="0">
<mxGeometry x="40" y="360" width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="56" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="57" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="58" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="59" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="350" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="59" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="60" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="61" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="62" target="59" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="62" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="63" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="56" target="57" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="64" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="65" target="62" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="65" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="66" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="57" target="65" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="67" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" target="69" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="236.67904761904765" y="125"/>
<mxPoint x="292.38095238095235" y="125"/>
<mxPoint x="355" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="68" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" source="69" target="60" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="69" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" vertex="1">
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="70" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="56" target="65" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="34.25858250276859" y="130"/>
<mxPoint x="89.96048726467328" y="130"/>
<mxPoint x="155" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="71" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="72" target="56" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="72" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#FF3333;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="73" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="74" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="75" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="76" target="79" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="76" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="77" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="60" target="76" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="78" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="79" target="85" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="79" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="80" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="81" target="83" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="81" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="82" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="83" target="87" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="83" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="84" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="85" target="81" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="85" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="86" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="87" target="88" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="87" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="88" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="89" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="90" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="91" source="89" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

View File

@@ -0,0 +1,192 @@
<mxfile host="65bd71144e">
<diagram name="GPT Architecture" id="DEYydPS-O6mnllJWumln">
<mxGraphModel dx="2176" dy="1029" grid="1" gridSize="10" guides="1" tooltips="1" connect="1" arrows="1" fold="1" page="1" pageScale="1" pageWidth="827" pageHeight="1169" math="0" shadow="0">
<root>
<mxCell id="0"/>
<mxCell id="1" parent="0"/>
<mxCell id="107" value="" style="group" vertex="1" connectable="0" parent="1">
<mxGeometry x="120" y="170" width="1286" height="265" as="geometry"/>
</mxCell>
<mxCell id="92" value="" style="group" parent="107" vertex="1" connectable="0">
<mxGeometry width="1286" height="160" as="geometry"/>
</mxCell>
<mxCell id="3" value="" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#fff2cc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="230" width="440" height="160" as="geometry"/>
</mxCell>
<mxCell id="4" value="&lt;div&gt;Masked&lt;/div&gt;Multi+Head&lt;br&gt;Attention" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="51.42776556776556" y="50" width="78.97435897435898" height="60" as="geometry"/>
</mxCell>
<mxCell id="22" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="5" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="350" y="80" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="5" value="Feed&lt;div&gt;Forward&lt;/div&gt;&lt;div&gt;Network&lt;/div&gt;" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#FF3333;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="260.9564102564102" y="50" width="71.9230769230769" height="60" as="geometry"/>
</mxCell>
<mxCell id="7" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="379.997619047619" y="60" width="37.87142857142857" height="40" as="geometry"/>
</mxCell>
<mxCell id="21" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="12" target="5" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="12" value="Norm" style="rounded=0;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="177.14285714285714" y="60" width="41.904761904761905" height="40" as="geometry"/>
</mxCell>
<mxCell id="13" value="" style="endArrow=classic;html=1;exitX=0;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=elbowEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="3" target="4" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="20" y="80.00000000000011" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="14" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="18" target="12" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="155.71428571427464" y="79.99999999999989" as="sourcePoint"/>
<mxPoint x="229.52380952380952" y="-50" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="18" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="150.00428571428571" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="19" value="" style="endArrow=classic;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;edgeStyle=orthogonalEdgeStyle;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="4" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="213.80952380952382" y="410" as="sourcePoint"/>
<mxPoint x="145.71428571428578" y="80.00000000000011" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="23" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="24" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="236.85714285714286" y="80" as="sourcePoint"/>
<mxPoint x="349.7619047619048" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="236.67904761904765" y="125"/>
<mxPoint x="292.38095238095235" y="125"/>
<mxPoint x="355" y="125"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="28" value="" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" source="24" target="7" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="24" value="+" style="ellipse;whiteSpace=wrap;html=1;aspect=fixed;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" vertex="1">
<mxGeometry x="350.00190476190477" y="75" width="10" height="10" as="geometry"/>
</mxCell>
<mxCell id="25" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="3" target="18" edge="1">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="34.325581395348834" y="80" as="sourcePoint"/>
<mxPoint x="150.71428571428578" y="85" as="targetPoint"/>
<Array as="points">
<mxPoint x="34.25858250276859" y="130"/>
<mxPoint x="89.96048726467328" y="130"/>
<mxPoint x="155" y="130"/>
</Array>
</mxGeometry>
</mxCell>
<mxCell id="104" value="" style="endArrow=none;dashed=1;html=1;" edge="1" parent="3">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="200" y="200" as="sourcePoint"/>
<mxPoint x="260.96000000000004" y="110" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="36" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="32" target="3" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="32" value="+" style="rounded=0;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="90" width="110" height="160" as="geometry"/>
</mxCell>
<mxCell id="33" value="Token Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#6c8ebf;fillColor=#dae8fc;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="17.5" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="34" value="Position Emb" style="rounded=0;whiteSpace=wrap;html=1;strokeColor=#9673a6;fillColor=#e1d5e7;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="95" y="100" width="100" height="42.5" as="geometry"/>
</mxCell>
<mxCell id="46" style="edgeStyle=none;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="37" target="40" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="37" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="690" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="38" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="7" target="37" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="47" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="40" target="44" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="40" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="790" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="49" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="41" target="42" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="41" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="950" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="52" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="42" target="50" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="42" value="Decoder" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#d5e8d4;strokeColor=#82b366;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1050" width="70" height="160" as="geometry"/>
</mxCell>
<mxCell id="48" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="44" target="41" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="44" value=".&lt;div&gt;.&lt;/div&gt;&lt;div&gt;.&lt;/div&gt;" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="890" y="40" width="30" height="80" as="geometry"/>
</mxCell>
<mxCell id="53" style="edgeStyle=none;html=1;exitX=1;exitY=0.5;exitDx=0;exitDy=0;entryX=0;entryY=0.5;entryDx=0;entryDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="50" target="51" edge="1">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="50" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1150" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="51" value="Softmax" style="rounded=1;whiteSpace=wrap;html=1;fillColor=#e1d5e7;strokeColor=#9673a6;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry x="1236" y="5" width="50" height="150" as="geometry"/>
</mxCell>
<mxCell id="54" value="Tokens" style="rounded=1;whiteSpace=wrap;html=1;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" vertex="1">
<mxGeometry y="40" width="60" height="90" as="geometry"/>
</mxCell>
<mxCell id="55" style="edgeStyle=none;html=1;entryX=-0.025;entryY=0.538;entryDx=0;entryDy=0;entryPerimeter=0;exitX=1;exitY=0.5;exitDx=0;exitDy=0;movable=1;resizable=1;rotatable=1;deletable=1;editable=1;locked=0;connectable=1;" parent="92" source="54" edge="1">
<mxGeometry relative="1" as="geometry">
<mxPoint x="42.75" y="84.66941747572821" as="sourcePoint"/>
<mxPoint x="90" y="85.33000000000004" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="105" value="" style="endArrow=none;dashed=1;html=1;exitX=1;exitY=1;exitDx=0;exitDy=0;" edge="1" parent="107" source="5">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint x="660" y="140" as="sourcePoint"/>
<mxPoint x="620" y="190" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="106" value="" style="group" vertex="1" connectable="0" parent="107">
<mxGeometry x="450" y="195" width="170" height="70" as="geometry"/>
</mxCell>
<mxCell id="100" value="" style="edgeStyle=none;html=1;" edge="1" parent="106" source="93" target="99">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="93" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;" vertex="1" parent="106">
<mxGeometry x="-5" y="20" width="70" height="30" as="geometry"/>
</mxCell>
<mxCell id="96" value="" style="endArrow=classic;html=1;entryX=0.5;entryY=1;entryDx=0;entryDy=0;" edge="1" parent="106" target="93">
<mxGeometry width="50" height="50" relative="1" as="geometry">
<mxPoint y="35" as="sourcePoint"/>
<mxPoint y="35.00999999999999" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="102" style="edgeStyle=none;html=1;exitX=0.5;exitY=0;exitDx=0;exitDy=0;" edge="1" parent="106" source="98">
<mxGeometry relative="1" as="geometry">
<mxPoint x="170" y="35.09433962264154" as="targetPoint"/>
</mxGeometry>
</mxCell>
<mxCell id="98" value="Linear" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#f5f5f5;fontColor=#333333;strokeColor=#666666;" vertex="1" parent="106">
<mxGeometry x="100" y="20" width="70" height="30" as="geometry"/>
</mxCell>
<mxCell id="101" value="" style="edgeStyle=none;html=1;" edge="1" parent="106" source="99" target="98">
<mxGeometry relative="1" as="geometry"/>
</mxCell>
<mxCell id="99" value="ReLU" style="rounded=1;whiteSpace=wrap;html=1;rotation=90;container=0;fillColor=#e1d5e7;strokeColor=#9673a6;" vertex="1" parent="106">
<mxGeometry x="50" y="20" width="70" height="30" as="geometry"/>
</mxCell>
</root>
</mxGraphModel>
</diagram>
</mxfile>

Binary file not shown.

After

Width:  |  Height:  |  Size: 46 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 50 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 62 KiB

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"Open weights",
"The Llama model is",
"Efficient transformers"
],
"model_config_path": "checkpoints/gemma-bpe/config.json",
"model_weights": "checkpoints/gemma-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/gemma_only_generation_logs.json"
}

View File

@@ -0,0 +1,28 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["Open source AI", "What is Llama?"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_q_heads": 4,
"num_kv_heads": 2,
"head_size": 64,
"num_layers": 4,
"max_position_embeddings": 512,
"num_experts": 8,
"top_k_experts": 2,
"window_size": 16,
"dropout": 0.1
},
"model_weights": "checkpoints/gemma-bpe/model.pt",
"model_config_path": "checkpoints/gemma-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/gemma_only_training_logs.json"
}

View File

@@ -0,0 +1,19 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"test_prompts": [
"Open weights",
"The Llama model is",
"Efficient transformers"
],
"model_config_path": "checkpoints/mixtral-bpe/config.json",
"model_weights": "checkpoints/mixtral-bpe/model.pt",
"generation": {
"max_new_tokens": 40,
"temperature": 0.8,
"do_sample": true,
"top_k": null,
"top_p": null
},
"log_path": "checkpoints/mixtral_only_generation_logs.json"
}

View File

@@ -0,0 +1,28 @@
{
"bpe_tokenizer": "checkpoints/bpe_tokenizer.json",
"bpe_vocab_size": 1000,
"bpe_special_tokens": ["<pad>", "<unk>", "<bos>", "<eos>"],
"test_prompts": ["Open source AI", "What is Llama?"],
"model_config": {
"vocab_size": null,
"embed_dim": 256,
"num_q_heads": 4,
"num_kv_heads": 2,
"head_size": 64,
"num_layers": 4,
"max_position_embeddings": 512,
"num_experts": 8,
"top_k_experts": 2,
"window_size": 16,
"dropout": 0.1
},
"model_weights": "checkpoints/mixtral-bpe/model.pt",
"model_config_path": "checkpoints/mixtral-bpe/config.json",
"training": {
"learning_rate": 0.0003,
"batch_size": 2,
"num_epochs": 3,
"warmup_steps": 50
},
"log_path": "checkpoints/mixtral_only_training_logs.json"
}

View File

@@ -45,6 +45,12 @@ def load_model_class(model_name):
elif model_name.lower() == 'mistral':
from llm.models.mistral import Mistral
return Mistral
elif model_name.lower() == 'mixtral':
from llm.models.mixtral import Mixtral
return Mixtral
elif model_name.lower() == 'gemma':
from llm.models.gemma import Gemma
return Gemma
else:
raise ValueError(f"Модель '{model_name}' не поддерживается.")

140
llm/src/llm/core/geglu.py Normal file
View File

@@ -0,0 +1,140 @@
import torch
from torch import nn
from llm.core.gelu import GELU
class GeGLU(nn.Module):
"""
GeGLU (Gated GELU Linear Unit) — эффективная нелинейность для feed-forward блоков в современных трансформерах.
Назначение:
-----------
GeGLU — это вариант GLU (Gated Linear Unit), где «шлюз» реализован через GELU-активацию,
а затем поэлементно перемножается с другим линейным преобразованием. Такой gating-механизм позволяет повысить
выразительность MLP-блока и ускорить обучение, что подтверждено экспериментами на LLM (см. PaLM, LLaMA, T5).
Формула:
--------
GeGLU(x) = GELU(W_g x + b_g) ⊙ (W_u x + b_u) W_d + b_d
(здесь W_g, W_u, W_d — матрицы весов; GELU применяется к одной ветке, ⊙ — поэлементное умножение)
Структура блока:
----------------
1. gate = GELU(Linear_gate(x)) # ветка gating-а, shape [batch, seq, 4×emb]
2. up = Linear_up(x) # ветка передачи, shape [batch, seq, 4×emb]
3. out = gate * up # поэлементно, реализует динамическую фильтрацию информации
4. out = Linear_down(out) # проекция обратно в исходное пространство
5. out = Dropout(out) # регуляризация
Основные преимущества:
----------------------
- Позволяет эффективно обучать глубокие трансформеры (см. PaLM, LLaMA).
- Обеспечивает плавные градиенты за счёт GELU и gating-эффекта.
- Используется во многих современных LLM вместо обычных FFN или простых GLU.
Аргументы конструктора:
-----------------------
emb_size : int
Размер эмбеддинга (input и output).
dropout : float, по умолчанию 0.1
Dropout к финальному выходу (примерно 0.1-0.2 для регуляризации).
Пример использования:
---------------------
>>> geglu = GeGLU(emb_size=512, dropout=0.1)
>>> x = torch.randn(8, 16, 512)
>>> y = geglu(x)
>>> print(y.shape) # torch.Size([8, 16, 512])
Литература:
-----------
- Shazeer N., "GLU Variants Improve Transformer", 2020: https://arxiv.org/abs/2002.05202
- PaLM: https://arxiv.org/abs/2204.02311
- LLaMA: https://arxiv.org/abs/2302.13971
- T5: https://arxiv.org/abs/1910.10683
"""
def __init__(self, emb_size: int, dropout: float = 0.1):
"""
Инициализация блока GeGLU.
Создаёт три последовательных линейных слоя и задаёт GELU в качестве активации для ветки gating,
а также финальный dropout. Все размеры согласованы так, чтобы реализовать формулу GeGLU (см. описание класса).
Аргументы:
----------
emb_size : int
Размерность входного и выходного скрытого пространства (hidden size).
Данная величина определяет размерность эмбеддинга для всех внутренних вычислений.
Обычно равна размеру скрытого слоя трансформера.
dropout : float, по умолчанию 0.1
Вероятность отключения нейронов после выхода из блока (регуляризация).
Рекомендуемое значение: 0.1 (или чуть больше для небольших моделей).
Внутри:
-------
- self._gate: Linear слой размерности [emb_size, 4 * emb_size], ветка gating (проходит через GELU)
- self._up: Linear слой размерности [emb_size, 4 * emb_size], ветка передачи ("пропускная")
- self._down: Linear слой сжатия обратно к emb_size
- self._activation: Активация GELU для gating-ветки
- self._dropout: Dropout для выходного тензора
Пример:
-------
>>> block = GeGLU(emb_size=256, dropout=0.1)
>>> print(block)
"""
super().__init__()
self._gate = nn.Linear(emb_size, 4 * emb_size)
self._up = nn.Linear(emb_size, 4 * emb_size)
self._down = nn.Linear(4 * emb_size, emb_size)
self._activation = GELU()
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
"""
Прямой проход (forward) через блок GeGLU.
Для входного тензора скрытых состояний x реализует последовательность операций:
1. Gating-ветка: линейное преобразование → GELU-активация
2. Пропускная ветка: линейное преобразование
3. Поэлементное умножение результатов обеих веток (gating)
4. Проекция через Linear обратно к emb_size
5. Dropout результата для регуляризации
Математически:
--------------
gate = GELU(W_g·x + b_g)
up = W_u·x + b_u
out = gate * up
out = W_d·out + b_d
out = Dropout(out)
Аргументы:
----------
x : torch.Tensor
Входной тензор формы [batch_size, seq_len, emb_size]
(или любой совместимой формы, где последняя ось — emb_size).
Возвращает:
-----------
torch.Tensor :
Тензор той же формы [batch_size, seq_len, emb_size], прошедший через структуру GeGLU.
Пример:
-------
>>> y = geglu(x)
>>> print(y.shape) # [batch_size, seq_len, emb_size]
Примечания:
-----------
- Ветка gating строит masк для динамической фильтрации информации.
- Такой тип блока эффективно используется как замена обычного FFN в современных LLM.
"""
gate_out = self._gate(x) # [batch, seq, 4*emb]
activation_out = self._activation(gate_out) # [batch, seq, 4*emb]
up_out = self._up(x) # [batch, seq, 4*emb]
out = up_out * activation_out # поэлементное!
out = self._down(out) # [batch, seq, emb]
return self._dropout(out)

View File

@@ -0,0 +1,188 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.rope import RoPE
from llm.core.multi_query_attention import MultiQueryAttention
from llm.core.rms_norm import RMSNorm
from llm.core.geglu import GeGLU
class GemmaDecoder(nn.Module):
"""
GemmaDecoder — декодерный блок архитектуры Gemma (Google DeepMind, 2024).
Назначение:
-----------
Данный блок реализует одну «ячейку» декодерного стека в модели Gemma. Архитектура схожа с современными LLM (Llama/Mistral),
но имеет уникальные особенности attention и feed-forward слоёв, соответствующие спецификации Gemma.
Архитектурные компоненты:
-------------------------
- LayerNorm или RMSNorm
- Multi-head self-attention (обычно Multi-Query Attention)
- Skip connection (остаточное сложение)
- Feed-forward блок (может включать SwiGLU, GeGLU или классический FFN)
- Повторная нормализация
- Dropout (регуляризация на уровне attention и feed-forward)
Алгоритм прямого прохода:
-------------------------
1. norm1_out = LayerNorm(x)
2. attention_out = Attention(norm1_out, ...)
3. resid1 = attention_out + x
4. norm2_out = LayerNorm(resid1)
5. ffn_out = FeedForward(norm2_out)
6. output = ffn_out + resid1
Теоретические детали:
---------------------
- В Gemma используются техники оптимизации памяти и ускорения инференса (например, shared K/V-головы, Rope, кастомные FFN).
- Поддержка кэширования attention для ускорения генерации (KV cache).
- Блок проектирован для использования в стеке, повторяется N раз во всей LLM.
Аргументы конструктора:
----------------------
num_q_heads : int
Число голов query (Query Heads) для attention.
num_kv_heads : int
Число ключевых/значенческих голов (Key/Value Heads).
emb_size : int
Размерность скрытого пространства (embedding dim).
head_size : int
Размерность одной attention-головы.
max_seq_len : int
Максимальная длина последовательности (ограничение на causal mask).
dropout : float, optional
Dropout для регуляризации (примерно 0.00.1).
rope : RoPE, optional
Позиционное кодирование Rotary Position Embedding.
Пример использования:
---------------------
>>> decoder = GemmaDecoder(
... num_q_heads=8,
... num_kv_heads=2,
... emb_size=256,
... head_size=32,
... max_seq_len=1024,
... dropout=0.1,
... rope=rope_obj
... )
>>> x = torch.randn(2, 24, 256)
>>> out, cache = decoder(x, mask=None, use_cache=True, cache=None)
>>> print(out.shape) # torch.Size([2, 24, 256])
Литература и ссылки:
--------------------
- Gemma (официальный релиз): https://ai.google.dev/gemma
- Gemma paper: https://arxiv.org/abs/2403.07794
- Rotary Embedding: https://arxiv.org/abs/2104.09864
- Multi-Query Attention: https://arxiv.org/abs/1911.02150
- Llama: https://arxiv.org/abs/2302.13971
"""
def __init__(self,
num_q_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
rope: RoPE,
dropout: float = 0.1
):
"""
Конструктор слоя GemmaDecoder.
Производит инициализацию всех подслоёв (нормализация, multi-head или multi-query attention, feed-forward блок, Dropout)
согласно архитектуре декодера Gemma. Обеспечивает поддержку rotary-позиционирования, обучения и inference с caching.
Аргументы:
----------
num_q_heads : int
Количество query-голов в attention (определяет степень параллелизма внимания).
emb_size : int
Размер пространства эмбеддинга (embedding dim, input/output размерность слоя).
head_size : int
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
max_seq_len : int
Максимальная длина последовательности, для которой поддерживается attention и маскирование.
rope : RoPE
Объект для rotary positional encoding (позиционное кодирование для attention).
dropout : float, default=0.1
Dropout после attention и feed-forward для регуляризации (обычно 0.00.1).
Внутри:
-------
- Инициализируются все слои norm, attention, rope, FFN, остаточные соединения.
- Строится causal-маска автоагрессивного attention (если требуется).
- Гибко поддерживает работу как на training, так и для быстрых inference/генерации.
Пример:
-------
>>> decoder = GemmaDecoder(
... num_q_heads=8, emb_size=512, head_size=64, max_seq_len=1024, rope=rope_obj, dropout=0.05
... )
"""
super().__init__()
self._heads = MultiQueryAttention(
num_q_heads=num_q_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
rope=rope,
dropout=dropout
)
self._ff = GeGLU(emb_size=emb_size, dropout=dropout)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
"""
Прямой проход (forward) через GemmaDecoder.
Последовательно реализует:
- Нормализацию входа (обычно RMSNorm или LayerNorm)
- Self-attention (multi-query или multi-head, с опциональной маской и кэшем)
- Остаточное сложение (skip connection)
- Вторую нормализацию
- Feed-Forward-блок (например, GeGLU/SwiGLU)
- Ещё одно residual сложение
Поддерживает autoregressive режим с caching (KV-слоты attention для ускорения генерации).
Аргументы:
----------
x : torch.Tensor
Входной скрытый тензор формы [batch_size, seq_length, emb_size].
mask : torch.Tensor, optional
Attention mask (например, causal или padding mask). Если None, используется встроенная causal mask.
use_cache : bool, по умолчанию True
Если True — возвращается кэш KV для ускорения autoregressive генерации.
cache : list, optional
Кэш предыдущих ключей/значений attention (если используется при инференсе).
Возвращает:
-----------
Tuple[torch.Tensor, cache]:
- Выход декодера с той же формой [batch_size, seq_length, emb_size]
- Кэш attention (если use_cache=True), иначе None
Пример:
-------
>>> out, new_cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
>>> out.shape # [batch_size, seq_len, emb_size]
Примечания:
-----------
- mask используется для ограничения внимания (напр., каузальный режим GPT/LLM).
- Для ускорения в режиме генерации рекомендуется использовать use_cache=True + передавать cache.
"""
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)

View File

@@ -4,7 +4,7 @@ from .feed_forward import FeedForward
from .multi_head_attention import MultiHeadAttention
class Decoder(nn.Module):
class GptDecoder(nn.Module):
"""
Decoder базовый transformer decoder block (pre-LN), классический строительный блок современных языковых моделей.
@@ -94,7 +94,13 @@ class Decoder(nn.Module):
self._norm1 = nn.LayerNorm(emb_size)
self._norm2 = nn.LayerNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
use_cache: bool = False,
cache: list = None,
attention_mask=None
) -> tuple:
"""
Один прямой проход через Transformer decoder block.
@@ -117,10 +123,16 @@ class Decoder(nn.Module):
- Применяем FFN к нормализованному результату (layernorm)
- Добавляем residual-связь (ffn + предыдущий выход)
"""
# Self-Attention блок
attention, _ = self._heads(x, mask, use_cache=False, cache=None)
attention, kv_caches = self._heads(x, attention_mask, use_cache=use_cache, cache=cache)
out = self._norm1(attention + x)
# FeedForward блок
ffn_out = self._ff(out)
return self._norm2(ffn_out + out)
result = self._norm2(ffn_out + out)
if use_cache:
return (result, kv_caches)
else:
return (result, None)

View File

@@ -0,0 +1,211 @@
from torch import nn
import torch
import torch.nn.functional as F
from llm.core.rope import RoPE
from llm.core.group_query_attention import GroupedQueryAttention
from llm.core.moe import MoE
from llm.core.rms_norm import RMSNorm
class MixtralDecoder(nn.Module):
"""
MixtralDecoder — декодерный блок для Mixtral/MoE-трансформеров (см. Mixtral 8x7B, Mistral v0.2 и др.).
Назначение:
-----------
MixtralDecoder реализует один модульный слой глубокой трансформерной архитектуры с Mixture-of-Experts (MoE) Feed-Forward Network и Grouped Query Attention (GQA).
Поддерживает разреженную активацию и масштабируемое количество экспертов, оптимально для больших LLM.
Архитектура блока:
------------------
- RMSNorm -> Grouped Query Attention (GQA)
- skip-connection
- RMSNorm -> MoE (SwiGLU-эксперты)
- skip-connection
Для входа `x` проходит:
1. norm1_out = RMSNorm(x)
2. attention, kv_caches = GQA(norm1_out, ...)
3. out = attention + x # residual connection
4. norm2_out = RMSNorm(out)
5. ffn_out = MoE(norm2_out)
6. return (ffn_out + out, kv_caches)
Теоретическая мотивация:
------------------------
- Использование MoE (см. https://arxiv.org/abs/1701.06538) позволяет кратно увеличивать capacity без роста затрат на ff-часть.
- Grouped Query Attention эффективно масштабирует self-attention для больших моделей (см. Mistral, Llama 2/3).
- RMSNorm (Root Mean Square LayerNorm) стабилизирует градиенты и память.
- Является строительным блоком для стека декодеров в Mixtral-моделях (см. Mixtral, Mistral, LLaMA).
Аргументы конструктора:
----------------------
num_q_heads : int
Число query-голов в attention.
num_kv_heads : int
Число key-value голов (группировка ключей/values).
emb_size : int
Скрытый размер эмбеддинга.
head_size : int
Размерность одной головы (emb_size // num_q_heads).
max_seq_len : int
Максимальная поддерживаемая длина последовательности.
num_experts : int
Количество «экспертов» (MoE).
top_k_experts : int
Сколько одновременно экспертов активируется для одного токена.
window_size : int
Размер окна внимания (используется для efficient attention).
rope : RoPE
Реализация позиционного кодирования RoPE.
dropout : float
Вероятность Dropout для регуляризации.
Пример использования:
---------------------
>>> decoder = MixtralDecoder(... параметры ...)
>>> x = torch.randn(batch, seq, emb_size)
>>> out, cache = decoder(x, mask=None, use_cache=True)
>>> out.shape
Литература и ссылки:
--------------------
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
- Shazeer et al., “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
- Mistral paper: https://arxiv.org/abs/2310.06825
- GQA: https://arxiv.org/abs/2305.14236
- RMSNorm: https://arxiv.org/abs/1910.07467
"""
def __init__(self,
num_q_heads: int,
num_kv_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
num_experts: int,
top_k_experts: int,
window_size: int,
rope: RoPE,
dropout: float = 0.1
):
"""
Конструктор декодерного блока MixtralDecoder.
Осуществляет инициализацию всех под-компонентов слоя: Attention (Grouped Query Attention), MoE (Mixture-of-Experts, SwiGLU)
и нормализации (RMSNorm). Позволяет гибко настраивать архитектуру под специфику задач и размеры LLM.
Аргументы:
----------
num_q_heads : int
Количество голов внимания (queries) в механизме GroupedQueryAttention.
Чем больше — тем тоньше дискретизация внимания по подпространствам признаков.
num_kv_heads : int
Количество групп ключей/значений (key-value heads) для GQA.
Позволяет балансировать производительность и память.
emb_size : int
Размерность эмбеддингового пространства внутри слоя (hidden).
head_size : int
Размерность одной attention-головы. Обычно emb_size // num_q_heads.
max_seq_len : int
Максимально поддерживаемая длина токенизированной последовательности.
num_experts : int
Количество экспертов в слое MoE (размер пула SwiGLU-экспертов).
top_k_experts : int
Сколько экспертов по роутингу активируется на 1 токен (разреженность — эффективная экономия вычислений).
window_size : int
Размер окна для attention (может использоваться для ограничения receptive field, как в Mistral).
rope : RoPE
Объект позиционного кодирования RoPE (Rotary Positional Embedding), необходим для архитектуры внимания.
dropout : float, по умолчанию 0.1
Вероятность зануляции выходных значений для регуляризации и борьбы с переобучением.
Пример:
-------
>>> decoder = MixtralDecoder(
... num_q_heads=8,
... num_kv_heads=2,
... emb_size=256,
... head_size=32,
... max_seq_len=1024,
... num_experts=4,
... top_k_experts=2,
... window_size=128,
... rope=rope_module,
... dropout=0.05
... )
"""
super().__init__()
self._heads = GroupedQueryAttention(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
window_size=window_size,
rope=rope,
dropout=dropout
)
self._ff = MoE(
emb_size=emb_size,
num_experts=num_experts,
top_k_experts=top_k_experts,
dropout=dropout
)
self._norm1 = RMSNorm(emb_size)
self._norm2 = RMSNorm(emb_size)
def forward(self, x: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = True, cache: list = None) -> torch.Tensor:
"""
Прямой проход (forward) через декодерный блок MixtralDecoder.
Данный метод реализует последовательную обработку входных скрытых состояний (x) через:
- нормализацию (RMSNorm),
- attention-модуль (Grouped Query Attention) с опциональным применением маски и кэша ключей/значений для ускорения инференса,
- остаточное сложение (residual connection),
- повторную нормализацию,
- feed-forward блок на основе Mixture-of-Experts (MoE),
- финальное остаточное сложение.
Аргументы:
----------
x : torch.Tensor
Входной скрытый тензор формы [batch_size, seq_len, emb_size] — результат эмбеддинга токенов либо предыдущего слоя.
mask : torch.Tensor, optional
(Необязательно) Маска внимания для ограничения области self-attention (например, для автоперемешивания или causal-LLM-моделей).
use_cache : bool, по умолчанию True
Если True — сохраняет кэш ключей/значений attention для ускорения авторегрессии (инференса).
cache : list, optional
(Необязательно) Предварительно вычисленный кеш attention (для ускорения генерации длинного текста).
Возвращает:
-----------
Tuple[torch.Tensor, Any]:
- Первый элемент: скрытый тензор выхода слоя с той же формой, что вход (последовательный residual из attention и MoE-блока).
- Второй элемент: обновлённый кэш attention (если use_cache=True), иначе None.
Пример:
-------
>>> out, cache = decoder(x, mask=att_mask, use_cache=True, cache=old_cache)
>>> out.shape # [batch_size, seq_len, emb_size]
Примечания:
-----------
- Для autoregressive-генерации (GPT-like режимов) следует передавать mask и использовать use_cache=True.
- Реализация поддерживает произвольные батчи и длины последовательностей, в пределах max_seq_len слоя.
- Модуль MixtralDecoder обычно используется в виде стека (несколько подряд) внутри крупной LLM.
"""
norm1_out = self._norm1(x)
attention, kv_caches = self._heads(norm1_out, mask, use_cache=use_cache, cache=cache)
out = attention + x
norm2_out = self._norm2(out)
ffn_out = self._ff(norm2_out)
if use_cache is True:
return (ffn_out + out, kv_caches)
else:
return (ffn_out + out, None)

229
llm/src/llm/core/moe.py Normal file
View File

@@ -0,0 +1,229 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.swi_glu import SwiGLU
class MoE(nn.Module):
"""
MoE (Mixture of Experts) — слой «смеси экспертов» для современных трансформерных архитектур с разреженной активацией.
Назначение:
-----------
Класс реализует слой разреженного условного вычисления для увеличения capacity трансформеров без роста вычислительных затрат.
Для каждого токена из последовательности выбирается (с помощью роутера) наиболее подходящее подмножество экспертов (малых нейросетей).
Итоговый выход формируется как взвешенная сумма откликов экспертов, выбранных для данного токена.
Архитектурная схема:
---------------------
- Для каждого входного токена `x` роутер (обычно один Linear-слой) предсказывает skor, насколько каждый из `num_experts` релевантен.
- Для каждого токена выбираются top_k_experts с максимальными skor; только они обрабатывают этот токен.
- Каждый эксперт здесь представлен отдельным экземпляром блока `SwiGLU` (может быть любая небольшая feed-forward сеть).
- Выход каждого эксперта умножается на индивидуальный вес (softmax по skor) — агрегируется взвешенная сумма.
- Dropout применяется к итоговому выходу.
Математика (коротко):
---------------------
Пусть X ∈ R^{BxSxD} — вход,
E — число экспертов,
K — число активируемых экспертов на токен.
r(x) = softmax(W_r x) — роутинг-логиты, top-K берём индексы и веса.
Для каждого токена:
y_j = Expert_j(x)
y = sum_j(w_j * y_j), где j пробегает по выбранным экспертам
Output: Y ∈ R^{BxSxD}
Аргументы конструктора:
----------------------
emb_size : int
Размерность входных/выходных векторов (обычно совпадает с embedding модели).
num_experts : int
Общее число экспертов внутри слоя MoE.
top_k_experts : int
Сколько экспертов активировать и агрегировать на каждом токене (обычно 2-8).
dropout : float, по умолчанию 0.1
Dropout к выходу агрегатора.
Пример использования:
---------------------
>>> moe = MoE(emb_size=512, num_experts=8, top_k_experts=2, dropout=0.1)
>>> x = torch.randn(4, 16, 512)
>>> y = moe(x)
>>> y.shape # torch.Size([4, 16, 512])
Литература:
-----------
- Shazeer, N. et al. “Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer”, 2017. https://arxiv.org/abs/1701.06538
- Fedus, W., Zoph, B., & Shazeer, N. “Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity”, 2021. https://arxiv.org/abs/2101.03961
- Mistral/Mixtral: https://mistral.ai/news/mixtral-of-experts/
"""
def __init__(
self,
emb_size: int,
num_experts: int,
top_k_experts: int,
dropout: float = 0.1,
):
"""
Конструктор слоя MoE (Mixture of Experts).
Позволяет создать слой, состоящий из набора экспертов (например, отдельных небольших feedforward-нейросетей) и роутера,
который будет для каждого токена определять наиболее релевантных экспертов.
Часть экспертов (top_k_experts) активируется для каждого токена, остальные — пропускаются.
Аргументы:
----------
emb_size : int
Размерность входных и выходных векторов (embedding size).
Определяет, над каким пространством признаков будет работать роутер и эксперты.
Например, если скрытый размер слоя трансформера 512, сюда нужно передать 512.
num_experts : int
Общее количество экспертов в слое MoE.
Чем больше экспертов — тем больше capacity у модели, но тем выше требования к RAM/VRAM при обучении.
Пример: 8, 16, 32, 64.
top_k_experts : int
Сколько экспертов одновременно будет обрабатывать каждый токен.
Обычно 28. Меньшее значение — выше разреженность, больше экономия вычислений.
dropout : float, по умолчанию 0.1
Вероятность зануления значений на выходе после агрегации откликов экспертов.
Используется для регуляризации (борьбы с переобучением).
Пример:
-------
>>> moe = MoE(emb_size=256, num_experts=8, top_k_experts=2, dropout=0.1)
>>> print(moe)
MoE( ... )
Теория:
-------
Слой строит:
- Линейный роутер (Linear(emb_size, num_experts)): выдает «важность» каждого эксперта для токена.
- Список из num_experts экспертов (в данной реализации — SwiGLU-блоки).
При каждом проходе для каждого токена выбираются top_k_experts наиболее релевантных экспертов,
их ответы агрегируются взвешенной суммой (softmax по роутерным логитам).
"""
super().__init__()
if top_k_experts > num_experts:
raise ValueError(f"top_k_experts ({top_k_experts}) должен быть меньше или равен num_experts ({num_experts})!")
self._num_experts = num_experts
self._top_k_experts = top_k_experts
self._router = nn.Linear(emb_size, num_experts)
self._experts = nn.ModuleList([SwiGLU(
emb_size=emb_size,
dropout=dropout,
) for _ in range(num_experts)])
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
"""
Прямой проход (forward) через слой MoE.
Для входной последовательности скрытых состояний (обычно из предыдущего слоя трансформера)
данный метод динамически выбирает для каждого токена топ-k наиболее релевантных экспертов с помощью роутера,
пропускает соответствующие токены через выбранных экспертов и агрегирует их результаты.
Математически:
--------------
1. Для каждого токена вычисляются логиты маршрутизатора (роутера):
router_logits = Linear(x) ∈ ^{batch, seq, num_experts}
2. Выбираются top_k экспертов (topk_indices) и соответствующие им softmax-веса (topk_weights).
3. Каждый эксперт обрабатывает только свой поднабор токенов.
4. Результат агрегируется — отклик эксперта умножается на вес, ответы суммируются для каждого токена.
5. На результат применяется dropout для регуляризации.
Аргументы:
----------
x : torch.Tensor
Трёхмерный входной тензор формы [batch_size, seq_length, emb_size],
где batch_size — размер батча, seq_length — длина последовательности, emb_size — размерность эмбеддинга.
Возвращает:
-----------
torch.Tensor :
Тензор той же формы [batch_size, seq_length, emb_size] — результат комбинирования выходов выбранных экспертов
с учетом softmax-весов маршрутизатора и dropout'а.
Пример:
-------
>>> y = moe(x)
>>> print(y.shape)
torch.Size([batch_size, seq_length, emb_size])
Примечание:
-----------
- Каждый токен чаще всего активирует только подмножество экспертов.
- Остальные эксперты вычислительно “спят”, что позволяет строить очень большие (по параметрам) модели с малым ростом затрат.
- Работа с распределением топ-к экспертов и агрегирование с весами реализовано автоматически.
"""
batch_size, seq_len, emb_size = x.shape
# 1. Пропускаем через роутер
router_logits = self._router(x) # [batch_size, seq_len, num_experts]
# 2. Отбираем топ-k экспертов для каждого токена
topk_logits, topk_indices = torch.topk(
router_logits,
k=self._top_k_experts,
dim=-1
) # topk_logits: [batch_size, seq_len, top_k]
# topk_indices: [batch_size, seq_len, top_k]
# 3. Получаем веса через softmax и нормируем
topk_weights = F.softmax(topk_logits, dim=-1) # [batch_size, seq_len, top_k]
# 4. Создаём нулевой тензор для результата
output = torch.zeros_like(x) # [batch_size, seq_len, emb_size]
# 5. Проходим по всем экспертам
for expert_id in range(self._num_experts):
# Шаг 1: Создаём маску - где находится текущий эксперт в топ-k
expert_mask = (topk_indices == expert_id) # [batch_size, seq_len, top_k]
# Шаг 2: Проверяем, выбран ли эксперт хотя бы одним токеном
if not expert_mask.any():
continue # Эксперт никем не выбран, переходим к следующему
# Шаг 3: Находим токены, которые выбрали этого эксперта
# (хотя бы в одной из top_k позиций)
token_mask = expert_mask.any(dim=-1) # [batch_size, seq_len]
# Шаг 4: Отбираем токены из x
# Отбираем токены для этого эксперта
expert_input = x[token_mask]
# Пропускаем через эксперта
# Добавляем batch dimension для SwiGLU и затем убираем
expert_output = self._experts[expert_id](
expert_input.unsqueeze(0)
).squeeze(0)
# Получаем веса для этого эксперта
# Для каждого токена может быть несколько весов (если эксперт в топ-k несколько раз)
# Но на практике каждый эксперт появляется максимум 1 раз в топ-k
# Находим веса: где expert_mask == True, берём соответствующий вес
weights_for_expert = torch.zeros(
batch_size, seq_len, device=x.device
)
# Для каждой позиции в топ-k
for k in range(self._top_k_experts):
mask_k = topk_indices[:, :, k] == expert_id
weights_for_expert[mask_k] = topk_weights[:, :, k][mask_k]
# Отбираем только веса для выбранных токенов
selected_weights = weights_for_expert[token_mask] # [num_selected_tokens]
# Перемножьте выход эксперта на веса текущего эксперта.
weighted_output = selected_weights.unsqueeze(-1) * expert_output
# Помещаем результат на своё место в выходном тензоре
output[token_mask] += weighted_output
out = self._dropout(output)
return out

View File

@@ -0,0 +1,252 @@
import torch
from torch import nn
import torch.nn.functional as F
from llm.core.rope import RoPE
class MultiQueryAttention(nn.Module):
"""
Multi-Query Attention (MQA) — быстрый и экономичный вариант self-attention для LLM.
Назначение:
-----------
Класс реализует механизм внимания (self-attention), в котором для всех Query-голов используются одни и те же Key и Value.
В классическом MultiHeadAttention (MHA) на каждый Query используется свой Key/Value. В MQA набор Key/Value общий для всех голов,
что снижает требования к памяти и ускоряет работу, что особенно важно для больших LLM на inference.
Теоретическое преимущество:
--------------------------
- Существенно экономит память на матрицы Key и Value: количество KV-голов обычно в 48 раз меньше, чем число Query-голов.
- Позволяет достигать скорости почти обычной MHA при минимальной потере точности (см. Llama, Mistral).
- Является стандартом де-факто для deployment и inference современных LLM.
Архитектурная схема:
--------------------
- Для каждого токена во входе вычисляются Q_h (отдельные для каждой Query-головы), но K и V — общие для всех.
- Attention внутри каждой головы формируется через матричный продукт соответствующей Q_h и общего K.
- Выходные вектора голов конкатенируются и проецируются обратно в emb_size.
Формулы:
--------
Q = Wq·x, K = Wk·x, V = Wv·x
(Wq — отдельные для всех Query, Wk/Wv — общие для всех голов)
Attention_h(x) = softmax(Q_h·K^T / sqrt(d_k))·V
Output = Concat_h([Attention_h(x)])·W_o
Аргументы конструктора:
-----------------------
emb_size : int
Размерность скрытого пространства (hidden size, embedding dim).
num_heads : int
Число Query-голов (обычно 832 в LLM).
kv_heads : int
Число Key/Value-голов (обычно 1, 2, 4, 8).
head_size : int, optional
Размерность одной головы (обычно emb_size // num_heads).
dropout : float, optional
Вероятность Dropout для регуляризации внимания.
Пример использования:
---------------------
>>> mqa = MultiQueryAttention(emb_size=512, num_heads=8, kv_heads=1)
>>> x = torch.randn(2, 16, 512)
>>> mask = torch.ones(2, 16, 16)
>>> out = mqa(x, mask)
>>> print(out.shape) # torch.Size([2, 16, 512])
Литература и статьи:
--------------------
- Shazeer, N., “Fast Transformer Decoding: One Write-Head Is All You Need” (MQA): https://arxiv.org/abs/1911.02150
- Llama: https://arxiv.org/abs/2302.13971
- Mistral: https://arxiv.org/abs/2310.06825
- PaLM/PaLM2, Mixtral, ChatGLM: практическое описание MQA.
"""
def __init__(
self,
num_q_heads: int,
emb_size: int,
head_size: int,
max_seq_len: int,
rope: RoPE = None,
dropout: float = 0.1,
):
"""
Конструктор MultiQueryAttention.
Инициализирует все слои и буферы для реализации Multi-Query Attention с общими K/V-головами и индивидуальными Q-головами.
Позволяет существенно ускорять инференс и экономить память при работе с большими языковыми моделями.
Аргументы:
----------
num_q_heads : int
Число query-голов (обычно совпадает с количеством attention heads в модели).
Определяет количество параллельных subspace для запроса.
emb_size : int
Размер скрытого пространства embedding (input/output размерность attention слоя).
head_size : int
Размерность одной attention-головы.
Обычно emb_size // num_q_heads.
max_seq_len : int
Максимально поддерживаемая длина последовательности (нужна для построения треугольной маски causal attention).
rope : RoPE, optional
Модуль для rotary positional encoding (позиционный энкодер, улучшает обобщающую способность attention).
Если None, positional encoding не применяется.
dropout : float, по умолчанию 0.1
Вероятность dropout для выходного слоя attention (регуляризация).
Внутри:
-------
- Насчитывает отдельные весовые слои для Q, общие для всех голов K/V.
- Строит causal маску для автогрессивной генерации.
- (Опционально) использует RoPE для позиционного кодирования.
- Dropout применяется после финального projection.
Пример:
-------
>>> mqa = MultiQueryAttention(emb_size=256, num_q_heads=8, head_size=32, max_seq_len=2048, rope=None, dropout=0.1)
"""
super().__init__()
self._num_q_heads = num_q_heads
self._head_size = head_size
self._max_seq_len = max_seq_len
self._rope = rope
self._q = nn.Linear(emb_size, num_q_heads * head_size)
self._k = nn.Linear(emb_size, head_size)
self._v = nn.Linear(emb_size, head_size)
# Создание causal маски
mask = torch.tril(torch.ones(max_seq_len, max_seq_len))
self.register_buffer(
"_tril_mask", mask.bool() if hasattr(torch, "bool") else mask.byte()
)
self._layer = nn.Linear(num_q_heads * head_size, emb_size)
self._dropout = nn.Dropout(dropout)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
use_cache: bool = True,
cache: list = None,
):
"""
Прямой проход (forward) через слой MultiQueryAttention.
Реализует multi-query self-attention для входных последовательностей с оптимизацией памяти за счёт общих K/V-голов для всех Query.
Поддерживает работу с rotary positional encoding (RoPE), каузальной маской и кэшированием для ускорения генерации.
Аргументы:
----------
x : torch.Tensor
Входной тензор формы [batch_size, seq_len, emb_size] — скрытые состояния после предыдущего слоя или эмбеддинга.
mask : torch.Tensor, optional
Необязательная маска внимания (например, для padding или custom-маскировки). По умолчанию используется встроенная causal mask.
use_cache : bool, по умолчанию True
Если True, возвращает кэш ключей/значений (для autoregressive inference/generation).
cache : list, optional
(K_cache, V_cache) — предварительный кэш KV (для ускоренного инференса). Если None, кэш не используется/создаётся заново.
Возвращает:
-----------
если use_cache == True:
Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
- attention_out: [batch_size, seq_len, emb_size] — результат attention после проекции и dropout.
- (K, V): кэшированные ключи и значения (использовать для последующих forward'ов в autoregressive генерации)
если use_cache == False:
Tuple[torch.Tensor, None]
Математические шаги:
--------------------
1. Q = Wq·x; K = Wk·x; V = Wv·x # Q: индивидуальные для каждой головы, K/V — общие
2. [optional] Rotary positional encoding применяется к Q и K
3. (optional) concat c k/v cache (for autoregressive inference)
4. attention_scores = softmax(Q·K^T / sqrt(head_size), mask)
5. attention_out = attention_scores·V
6. heads сливаются и проецируются в emb_size; применяется dropout.
Пример:
-------
>>> out, cache = mqa(x, mask=attn_mask, use_cache=True, cache=prev_cache)
>>> print(out.shape) # torch.Size([batch_size, seq_len, emb_size])
Примечания:
-----------
- Для каузального режима используется треугольная маска (по умолчанию).
- Для генерации текста с cache передавайте кэш от предыдущих токенов — это ускоряет autoregressive inference.
- Внимание! Тензоры внутри cache должны иметь форму [batch, heads, seq_len, head_size].
"""
batch_size, seq_len, emb_size = x.shape
if seq_len > self._max_seq_len:
raise ValueError(
f"Длина последовательности {seq_len} превышает максимум {self._max_seq_len}"
)
# Пропустите тензор x через матрицы Wq, Wk , Wv, чтобы получить матрицы запроса, ключа и значения.
k = self._k(x) # [B, T, hs]
q = self._q(x) # [B, T, hs]
v = self._v(x) # [B, T, hs]
# Шаг 2: Изменение формы для multi-head
# [batch_size, seq_len, num_heads * head_size]
# -> [batch_size, seq_len, num_heads, head_size]
q = q.reshape(batch_size, seq_len, self._num_q_heads, self._head_size)
k = k.reshape(batch_size, seq_len, 1, self._head_size)
v = v.reshape(batch_size, seq_len, 1, self._head_size)
# 3. Transpose: [B, T, H, hs] -> [B, H, T, hs]
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
# Пропустите матрицы запроса и ключа через экземпляр rope, чтобы выполнить поворот.
if self._rope is not None:
# Применяем RoPE к Q и K (НЕ к V!)
q = self._rope(q) # [B, T, hs]
k = self._rope(k) # [B, T, hs]
# Если cache пришел, то объединяем кэш и одну строку из ключа и значения. Это будут новые key и value для последующих вычислений.
# 5. Кэширование (для autoregressive generation)
if cache is not None:
k_cache, v_cache = cache
k = torch.cat([k_cache, k], dim=2) # Concat по seq_len (dim=2)
v = torch.cat([v_cache, v], dim=2)
# Перемножим матрицы запроса и ключа (транспонированную), чтобы вычислить матрицу внимания.
# И разделить все значения в матрице внимания на корень из head_size.
scores = q @ k.transpose(-2, -1) / (self._head_size ** 0.5)
# Если cache пришел, то маску не накладываем. Иначе наложите на матрицу внимания треугольную маску, созданную при инициализации. Все скрытые значения должны быть приведены к минус бесконечности: float('-inf').
if cache is None:
scores = scores.masked_fill(
~self._tril_mask[:seq_len, :seq_len], float("-inf")
)
# Применить к матрице внимания (построчно) функцию Softmax.
weights = F.softmax(scores, dim=-1)
# Перемножим матрицу внимания и матрицу значения.
x_out = weights @ v # [B, T, hs]
# Измените форму тензора на batch_size × seq_len × num_heads*head_size.
# Transpose обратно и concatenate heads
x_out = x_out.transpose(1, 2) # [B, T_q, H, hs]
x_out = x_out.contiguous() # Важно для reshape!
concatenated_attention = x_out.reshape(batch_size, seq_len, self._num_q_heads * self._head_size)
# Пропустите получившийся тензор через последний линейный слой.
# 3. Проецируем в пространство эмбеддингов
projected_output = self._layer(concatenated_attention)
# 4. Применяем dropout для регуляризации
final_output = self._dropout(projected_output)
if use_cache is True:
return (final_output, (k, v))
else:
return (final_output, None)

View File

@@ -0,0 +1,3 @@
from .gemma import Gemma
__all__ = ["Gemma"]

View File

@@ -0,0 +1,346 @@
import torch
import math
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.rope import RoPE
from llm.core.rms_norm import RMSNorm
from llm.core.gemma_decoder import GemmaDecoder
class Gemma(BaseModel):
"""
Gemma — языковая трансформер-модель от Google, с архитектурой, оптимизированной для open-source и research-комьюнити.
Назначение:
-----------
Модель Gemma реализует стек современных декодерных блоков (GemmaDecoder), поддерживает rotary-позиционирование, multi-query self-attention,
эффективный режим генерации (KV-cache), dropout, compact residual connections, базируется на best-practice LLM-инженерии последних лет.
Поддерживает batched-тренировку и inference, генерацию с различными стратегиями выборки (greedy, top-k, top-p), автосохранение.
Архитектурные особенности:
--------------------------
- Stack из N слоёв GemmaDecoder (attention с Multi-Query либо Grouped heads, FFN с GeGLU/SwiGLU)
- RMSNorm или LayerNorm для стабилизации
- Dropout для регуляризации
- Rotary Position Embedding (RoPE) для позиционных кодов
- Выходная проекция (linear → logits) к словарю токенов
- Полная поддержка cache для ускорения autoregressive генерации
Конфиг/Параметры конструктора:
------------------------------
config : dict
Словарь c параметрами модели:
- vocab_size : int — размер словаря
- embed_dim : int — размер скрытого (hidden) пространства
- max_position_embeddings : int — максимальная длина последовательности
- num_layers : int — количество декодерных блоков
- num_q_heads : int — количество attention голов (Queries)
- num_kv_heads : int — количество ключевых/значенческих attention голов
- dropout : float — Dropout率
- ... (доп. гиперпараметры, требуемые GemmaDecoder'ами)
Основные методы:
----------------
- forward(x, use_cache=True, cache=None): выдает батч логитов по токенам, возвращает при необходимости обновленный cache.
- generate(...): автотекстогенерация с greedy, temperature, top-k/p sampling, поддержкой кэша (ускорение inference).
- save(path)/load(path, device): сохранение и загрузка предобученных весов, параметров и состояния.
Пример:
-------
>>> config = {...} # словарь с параметрами
>>> model = Gemma(config)
>>> x = torch.randint(0, config["vocab_size"], (4, 64))
>>> logits, cache = model(x, use_cache=True)
>>> print(logits.shape) # [4, 64, vocab_size]
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.8)
Литература и ссылки:
--------------------
- Gemma: https://ai.google.dev/gemma (официальная страница)
- Разработка и архитектура: https://arxiv.org/abs/2403.07794
- Rotary Embedding: https://arxiv.org/abs/2104.09864
- Multi-Query Attention: https://arxiv.org/abs/1911.02150
- Llama: https://arxiv.org/abs/2302.13971
"""
def __init__(self, config):
"""
Конструктор класса Gemma.
Позволяет создать объект языковой модели с архитектурой Gemma и
произвольной конфигурацией (гибкая поддержка разных масштабов, ширин, глубин).
Аргументы:
----------
config : dict
Словарь со всеми необходимыми гиперпараметрами и архитектурными детальями модели Gemma.
Ожидаемые ключи (группы параметров):
- vocab_size : int — размер словаря токенов (размерность входа/выхода)
- embed_dim : int — скрытый размер эмбеддинга (hidden dim)
- max_position_embeddings : int — максимальная длина последовательности
- num_layers : int — количество декодерных блоков (глубина стека)
- num_q_heads : int — число attention голов (Query heads)
- num_kv_heads : int — число голов для Key/Value (MultiQuery Attention)
- dropout : float — Dropout для регуляризации
- остальные специфичные для GemmaDecoder'ов параметры
Внутри:
-------
- Инициализируются модули эмбеддинга токенов, позиционного кодирования (RoPE) и Dropout,
стек декодеров (GemmaDecoder(...)), слой финальной нормализации и выходная проекция (linear).
- Все архитектурные параметры напрямую берутся из config.
Пример:
-------
>>> config = {
... "vocab_size": 32000,
... "embed_dim": 512,
... "max_position_embeddings": 2048,
... "num_layers": 24,
... "num_q_heads": 8,
... "num_kv_heads": 4,
... "dropout": 0.1,
... }
>>> model = Gemma(config)
Примечание:
-----------
- Внимание: значения config должны быть согласованы друг с другом! Например, embed_dim должен быть кратным num_q_heads и т.д.
- Поддерживается дальнейшая кастомизация стека декодеров через ключи в config.
"""
super().__init__(config)
self._max_seq_len = config["max_position_embeddings"]
# Инициализация слоев
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
)
self._position_embeddings = RoPE(
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"]
)
#self._position_embeddings = PositionalEmbeddings(
# max_seq_len=max_seq_len,
# emb_size=emb_size
#)
self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([GemmaDecoder(
num_q_heads=config["num_q_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"],
rope=self._position_embeddings,
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._norm = RMSNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
"""
Прямой проход (forward) через полную модель Gemma.
Трансформирует входную последовательность токенов через стек из декодерных блоков GemmaDecoder.
Возвращает логиты по всем токенам и (при необходимости) кэш attention для быстрой autoregressive-генерации.
Аргументы:
----------
x : torch.Tensor
Входной тензор shape [batch_size, seq_len], содержащий токен-IDs.
use_cache : bool, по умолчанию True
Если True — сохраняет и возвращает KV-кэш attention (ускоряет автогенерацию).
Если False — кэш не используется.
cache : list, optional
(Необязательно) Список/None: с кэшами KV-матриц для каждого слоя (для режима генерации статей/диalogов).
Возвращает:
-----------
tuple:
- logits : torch.Tensor shape [batch_size, seq_len, vocab_size]
Логиты по словарю для каждого токена (input + сколь угодно новых).
- new_cache : list или None
Обновлённый cache (если use_cache=True).
Пример:
-------
>>> logits, new_cache = model(x, use_cache=True, cache=None)
>>> logits.shape # [batch_size, seq_len, vocab_size]
Примечания:
-----------
- Используется при обучении и инференсе.
- Если нужно только инференс last-token — используйте logits[:, -1, :].
- При превышении x.shape[1] > max_seq_len выдаёт ValueError.
"""
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
out = self._norm(out)
logits = self._linear(out)
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с использованием greedy, temperature, top-k и top-p sampling.
Реализует generation-loop с обновлением attention-кэша для ускорения инференса.
Аргументы:
----------
x : torch.Tensor
Входной тензор с последовательностью токенов (shape [batch_size, seq_len]), который необходимо продолжить.
max_new_tokens : int
Сколько новых токенов сгенерировать (максимум).
do_sample : bool
Если True — сэмплирует следующий токен согласно распределению вероятностей (stochastic), иначе выбирает токен с максимальной вероятностью (greedy).
temperature : float, default=1.0
Параметр для шкалирования распределения вероятностей логитов. Больше 1.0 — больше случайности, меньше 1.0 — более детерминированный (жёсткий) выбор.
top_k : int, optional
Если задано — для сэмплирования учитываются только top_k наиболее вероятных токенов.
top_p : float, optional
Если задано — работают nucleus sampling: учитываются токены, суммарная вероятность которых не превышает top_p.
use_cache : bool, default=True
Если True — для ускорения использует и обновляет attention-кэши (KV-cache).
Возвращает:
-----------
torch.Tensor
Тензор shape [batch_size, seq_len + max_new_tokens] с исходными и сгенерированными токенами (token IDs).
Пример:
-------
>>> out = model.generate(
... x, max_new_tokens=20, do_sample=True, temperature=0.8, top_k=50
... )
>>> print(out.shape) # [batch_size, seq_len+20]
Примечания:
-----------
- Нельзя указывать одновременно top_k и top_p (будет выброшено исключение).
- temperature <= 0 некорректно (будет выброшено исключение).
- Поддержка cache (use_cache=True) значительно ускоряет генерацию длинных последовательностей и позволяет использовать beam search/decoding.
- Для воспроизводимых результатов установите torch.manual_seed перед генерацией.
- Метод возвращает только token_ids, если нужны logits — используйте .forward напрямую.
Литература:
-----------
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
- Gemma: https://arxiv.org/abs/2403.07794
"""
cache = None
for _ in range(max_new_tokens):
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
# Масштабируем логиты температурой
if temperature > 0:
logits_scaled = last_logits / temperature
else:
logits_scaled = last_logits
if do_sample == True and top_k != None:
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
# # Заменим все НЕ top-k логиты на -inf
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
@property
def max_seq_len(self) -> int:
return self._max_seq_len

View File

@@ -26,7 +26,7 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict
from llm.core.base_model import BaseModel
from llm.core.decoder import Decoder
from llm.core.gpt_decoder import GptDecoder
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.positional_embeddings import PositionalEmbeddings
@@ -116,7 +116,7 @@ class GPT(BaseModel):
# head_size = emb_size // num_heads
self._decoders = nn.ModuleList(
[
Decoder(
GptDecoder(
num_heads=config["num_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_heads"],
@@ -133,7 +133,9 @@ class GPT(BaseModel):
"""Возвращает максимальную длину последовательности."""
return self._max_seq_len
def forward(self, x: torch.Tensor, attention_mask=None) -> torch.Tensor:
def forward(
self, x: torch.Tensor, attention_mask=None, use_cache: bool = True, cache: list = None
) -> tuple:
"""
Прямой проход для получения логитов по последовательности токенов.
@@ -157,33 +159,60 @@ class GPT(BaseModel):
f"Длина последовательности {x.size(1)} превышает максимальную {self._max_seq_len}"
)
# Вычисление start_pos из кэша (если кэш передан)
if cache is not None:
seq_len = 1
# Безопасно извлекаем key_cache для вычисления start_pos
if (
isinstance(cache, (list, tuple))
and len(cache) > 0
and cache[0] is not None
and isinstance(cache[0], (list, tuple))
and len(cache[0]) > 0
and cache[0][0] is not None
and isinstance(cache[0][0], (tuple, list))
and len(cache[0][0]) > 0
):
key_cache, _ = cache[0][0]
start_pos = key_cache.size(1)
else:
start_pos = 0
else:
# Без кэша работаем как раньше
start_pos = 0
seq_len = x.size(1)
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
pos_out = self._position_embeddings(x.size(1)) # [seq_len, emb_size]
pos_out = self._position_embeddings(
seq_len, start_pos=start_pos
) # [seq_len, emb_size]
# Комбинирование
out = self._dropout(
tok_out + pos_out.unsqueeze(0)
) # [batch, seq_len, emb_size]
# Стек декодеров
for decoder in self._decoders:
out = decoder(out)
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
return self._linear(out) # [batch, seq_len, vocab_size]
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
# def forward(self, input_ids, attention_mask=None):
# B, T = input_ids.size()
# pos = torch.arange(0, T, device=input_ids.device).unsqueeze(0)
#
# x = self.token_emb(input_ids) + self.pos_emb(pos)
#
# for block in self.blocks:
# x = block(x, attention_mask)
#
# x = self.ln_f(x)
# logits = self.head(x)
# return logits
logits = self._linear(out) # [batch, seq_len, vocab_size]
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(
self,
@@ -193,8 +222,9 @@ class GPT(BaseModel):
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
attention_mask: torch.Tensor = None, # Добавляем для совместимости с HF
**kwargs, # Игнорируем остальные параметры
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация текста с поддержкой жадного поиска (greedy), вероятностного сэмплирования с температурой,
@@ -244,12 +274,24 @@ class GPT(BaseModel):
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus sampling): https://arxiv.org/abs/1904.09751
- Оригинальный GPT-2: https://cdn.openai.com/better-language-models/language-models.pdf
"""
cache = None
for _ in range(max_new_tokens):
# 1. Обрезаем вход, если последовательность слишком длинная
x_cond = x[:, -self._max_seq_len :]
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# 2. Передаем последовательность в метод forward класса GPT и полуаем логиты.
logits = self.forward(x_cond)
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
# 3. Берем логиты для последнего токена
last_logits = logits[:, -1, :] # [batch_size, vocab_size]

View File

@@ -214,6 +214,8 @@ class GPT2(BaseModel):
top_k: int = None,
top_p: float = None,
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k, top-p sampling и KV-кэша.

View File

@@ -176,6 +176,8 @@ class Llama(BaseModel):
top_k: int = None,
top_p: float = None,
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).

View File

@@ -140,14 +140,17 @@ class Mistral(BaseModel):
else:
return (logits, None)
def generate(self,
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling

View File

@@ -0,0 +1,3 @@
from .mixtral import Mixtral
__all__ = ["Mixtral"]

View File

@@ -0,0 +1,361 @@
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F
from math import sqrt
from llm.core.base_model import BaseModel
from llm.core.token_embeddings import TokenEmbeddings
from llm.core.rope import RoPE
from llm.core.rms_norm import RMSNorm
from llm.core.mixtral_decoder import MixtralDecoder
class Mixtral(BaseModel):
"""
Mixtral — языковая модель с архитектурой Mixture-of-Experts на основе современных трансформеров (см. Mixtral 8x7B).
Описание:
---------
Данный класс реализует полностью функциональную LLM с блоками MixtralDecoder, которые используют разреженные Feed-Forward сети MoE (Mixture-of-Experts)
и Grouped Query Attention (GQA). Позволяет масштабировать количество параметров без экспоненциального роста вычислительных затрат благодаря активации лишь части экспертов на каждый токен.
Mixtral поддерживает автотекстогенерацию с caching, position encoding через RoPE и всё необходимое для работы и тренировки современных LLM.
Архитектурные особенности:
--------------------------
- Stack из N слоёв MixtralDecoder (каждый — MoE-блок + attention + RMSNorm).
- Dropout для регуляризации на уровне эмбеддингов и слоёв.
- Позиционные эмбеддинги реализованы через RoPE (Rotary Positional Embeddings).
- Финальная RMSNorm плюс Linear-проекция к словарю токенов.
- Поддержка автогенерации с sampling (greedy, top-k, top-p), temperature и KV-cache.
Аргументы конструктора:
----------------------
config : dict
Словарь-конфиг с основными гиперпараметрами модели:
- vocab_size : int — размер словаря токенов
- embed_dim : int — размер скрытого пространства
- max_position_embeddings : int — макс. длина последовательности
- num_layers : int — количество декодерных блоков в стеке
- num_q_heads : int — число query-голов в attention
- num_kv_heads : int — число kv-голов в attention
- num_experts : int — число MoE-экспертов
- top_k_experts : int — сколько экспертов активировать на токен
- dropout : float — вероятность Dropout
- window_size : int — размер окна внимания
Основные методы:
----------------
- forward(x, use_cache=True, cache=None) — прямой проход, поддерживает batched вход, caching.
- generate(...) — авторегрессивная генерация с разными стратегиями sampling и ускорением через cache.
- save(path)/load(path, device) — сохранение и восстановление обученной модели.
Пример:
-------
>>> config = {...} # dict с параметрами
>>> model = Mixtral(config)
>>> x = torch.randint(0, config["vocab_size"], (2, 16))
>>> logits, cache = model(x, use_cache=True)
>>> print(logits.shape) # [2, 16, vocab_size]
>>> # Генерация
>>> out = model.generate(x, max_new_tokens=20, do_sample=True, top_k=10, temperature=0.9)
Литература:
-----------
- Mixtral 8x7B: https://mistral.ai/news/mixtral-of-experts/
- Switch Transformer: https://arxiv.org/abs/2101.03961
- GShard: https://arxiv.org/abs/2006.16668
- RoPE: https://arxiv.org/abs/2104.09864
- Grouped Query Attention: https://arxiv.org/abs/2305.14236
- RMSNorm: https://arxiv.org/abs/1910.07467
"""
def __init__(self, config):
"""
Конструктор класса Mixtral.
Осуществляет инициализацию всех модулей и внутренних параметров большой языковой модели с архитектурой Mixtral/MoE.
Использует параметры из конфиг-словаря `config` для гибкой настройки модели.
Аргументы:
----------
config : dict
Словарь с основными гиперпараметрами архитектуры. Должен содержать ключи:
vocab_size (int): Размер словаря токенов.
embed_dim (int): Размер скрытого пространства (эмбеддингов).
max_position_embeddings (int): Максимальная длина токенной последовательности.
num_layers (int): Количество декодерных блоков (слоёв) в модели.
num_q_heads (int): Число query-голов (attention heads).
num_kv_heads (int): Число key-value голов (attention heads).
num_experts (int): Количество экспертов в каждом MoE-блоке.
top_k_experts (int): Сколько экспертов активируется для одного токена.
dropout (float): Dropout для регуляризации.
window_size (int): Размер окна внимания (Attention Window).
Внутри:
-------
- Инициализируются эмбеддинги токенов, позиционные эмбеддинги RoPE, Dropout.
- Строится стек из num_layers модулей MixtralDecoder с заданным количеством attention heads и экспертов.
- Финальный слой нормализации и проекция к логитам словаря (linear layer).
Пример:
-------
>>> config = {
... "vocab_size": 32000,
... "embed_dim": 512,
... "max_position_embeddings": 2048,
... "num_layers": 24,
... "num_q_heads": 8,
... "num_kv_heads": 8,
... "num_experts": 8,
... "top_k_experts": 2,
... "dropout": 0.1,
... "window_size": 256,
... }
>>> model = Mixtral(config)
Примечания:
-----------
- Конфиг модели должен быть согласован: размеры должны делиться на число голов, число экспертов и top_k_experts корректно выбраны.
- Все параметры, необходимые для построения MixtralDecoder, attention и MoE, берутся из config.
"""
super().__init__(config)
self._max_seq_len = config["max_position_embeddings"]
# Инициализация слоев
self._token_embeddings = TokenEmbeddings(
vocab_size=config["vocab_size"],
emb_size=config["embed_dim"]
)
self._position_embeddings = RoPE(
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"]
)
#self._position_embeddings = PositionalEmbeddings(
# max_seq_len=max_seq_len,
# emb_size=emb_size
#)
self._dropout = nn.Dropout(config["dropout"])
self._decoders = nn.ModuleList([MixtralDecoder(
num_q_heads=config["num_q_heads"],
num_kv_heads=config["num_kv_heads"],
emb_size=config["embed_dim"],
head_size=config["embed_dim"] // config["num_q_heads"],
max_seq_len=config["max_position_embeddings"],
num_experts=config["num_experts"],
top_k_experts=config["top_k_experts"],
window_size=config["window_size"],
rope=self._position_embeddings,
dropout=config["dropout"]
) for _ in range(config["num_layers"])])
self._norm = RMSNorm(config["embed_dim"])
self._linear = nn.Linear(config["embed_dim"], config["vocab_size"])
def forward(self, x: torch.Tensor, use_cache: bool = True, cache: list = None) -> tuple:
"""
Прямой проход (forward) через всю модель Mixtral.
Данный метод реализует трансформацию входной последовательности токенов в логиты (предсказания вероятностей токенов словаря)
с поддержкой эффективного инференса с использованием cache (KV-кэш attention для автогенерации).
Аргументы:
----------
x : torch.Tensor
Двумерный входной тензор shape [batch_size, seq_len], где каждое значение — ID токена.
use_cache : bool, по умолчанию True
Если True — в режиме генерации модель возвращает обновлённый список кэшей attention для ускорения последовательного инференса.
Если False — attention cache не используется.
cache : list, optional
(Необязательно) Список (или None) с кэшем KV attention для каждого слоя. Используется для автогенерации текста.
Возвращает:
-----------
tuple:
- logits : torch.Tensor — выходной тензор shape [batch_size, seq_len, vocab_size] — массив логитов по токенам и словарю.
- new_cache : list или None — обновлённый cache, если используется.
Пример:
-------
>>> logits, new_cache = model(x, use_cache=True, cache=None)
>>> logits.shape # [batch_size, seq_len, vocab_size]
Примечания:
-----------
- Если используется cache — эффективно для авторегрессионной генерации (token-by-token), например, при диалогах или длинной генерации.
- Если входная последовательность длиннее max_seq_len — будет выброшено исключение.
- Если нужен только логит последнего токена — используйте slice: logits[:, -1, :]
"""
# Проверка длины последовательности (только при отсутствии кэша)
if cache is None and x.size(1) > self._max_seq_len:
raise ValueError(f"Длина последовательности {x.size(1)} превышает максимальную {self.max_seq_len}")
# Эмбеддинги токенов и позиций
tok_out = self._token_embeddings(x) # [batch, seq_len, emb_size]
#pos_out = self._position_embeddings(x) # [batch, seq_len, emb_size]
# Комбинирование
out = self._dropout(tok_out) # [batch, seq_len, emb_size]
# Стек декодеров с передачей кэша
new_cache = []
for i, decoder in enumerate(self._decoders):
decoder_cache = cache[i] if cache is not None else None
decoder_result = decoder(out, use_cache=use_cache, cache=decoder_cache)
# Извлекаем результат из кортежа
if use_cache:
out, decoder_new_cache = decoder_result
new_cache.append(decoder_new_cache)
else:
out = decoder_result[0]
out = self._norm(out)
logits = self._linear(out)
# Возвращаем результат с учетом use_cache
if use_cache:
return (logits, new_cache)
else:
return (logits, None)
def generate(
self,
x: torch.Tensor,
max_new_tokens: int,
do_sample: bool,
temperature: float = 1.0,
top_k: int = None,
top_p: float = None,
use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor:
"""
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
и ускорением через attention-кэш (KV-cache, важно для inference на длинных текстах).
Аргументы:
x (torch.Tensor): Входной тензор с токенами shape [batch_size, seq_len].
max_new_tokens (int): Максимальное количество новых токенов для генерации.
do_sample (bool): Если True — вероятность/случайность (random sampling); если False — жадная генерация (argmax).
temperature (float): Температура (>0, по умолчанию 1.0); >1.0 — более случайные выборы, <1.0 — более строгие.
top_k (int, optional): top-k sampling; при сэмплировании выбираются только top_k наиболее вероятных токенов.
top_p (float, optional): nucleus (top-p) sampling; выбираются токены с накопленной вероятностью ≤ top_p.
use_cache (bool, по умолчанию True): Использовать ускорение через KV attention cache для autoregressive режима.
Возвращает:
torch.Tensor: Последовательность индексов токенов shape [batch_size, seq_len + max_new_tokens].
Исключения:
ValueError: Если x длиннее max_seq_len модели.
ValueError: Если temperature ≤ 0.
ValueError: Если одновременно заданы top_k и top_p.
ValueError: Если top_k ≤ 0.
ValueError: Если top_p не в диапазоне (0, 1].
Примеры:
>>> # Жадная генерация
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=False)
>>> # Сэмплирование с температурой
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=0.8)
>>> # Top-k sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_k=50)
>>> # Top-p (nucleus) sampling
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, top_p=0.92)
>>> # Температура + top-k
>>> out = model.generate(input_ids, max_new_tokens=16, do_sample=True, temperature=1.0, top_k=100)
Примечания:
- Одновременно использовать top_k и top_p нельзя.
- Параметры temperature, top_k, top_p работают только при do_sample=True.
- Для полного воспроизведения результата зафиксируйте seed через torch.manual_seed.
- Метод всегда возвращает только индексы токенов; для получения логитов используйте forward.
Ссылки:
- Holtzman et al., "The Curious Case of Neural Text Degeneration" (nucleus/top-p sampling): https://arxiv.org/abs/1904.09751
- Mistral: https://arxiv.org/abs/2310.06825
"""
cache = None
for _ in range(max_new_tokens):
if use_cache and cache is not None:
# Используем кэш - передаем только последний токен
x_input = x[:, -1:] # [batch_size, 1]
else:
# Первая итерация или кэш отключен - передаем всю последовательность
x_input = x
# Прямой проход с кэшем
logits, new_cache = self.forward(x_input, use_cache=use_cache, cache=cache)
# Обновляем кэш для следующей итерации
if use_cache:
cache = new_cache
last_logits = logits[:, -1, :] # [batch_size, vocab_size]
# Масштабируем логиты температурой
if temperature > 0:
logits_scaled = last_logits / temperature
else:
logits_scaled = last_logits
if do_sample == True and top_k != None:
_, topk_indices = torch.topk(logits_scaled, top_k, dim=-1)
# # Заменим все НЕ top-k логиты на -inf
masked_logits = logits_scaled.clone()
vocab_size = logits_scaled.size(-1)
# создаём маску: 1, если токен НЕ в topk_indices
mask = torch.ones_like(logits_scaled, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
mask.scatter_(1, topk_indices, False if hasattr(torch, "bool") else 0) # 0 там, где top-k индексы
masked_logits[mask.bool() if hasattr(torch, "bool") else mask.byte()] = float('-inf')
logits_scaled = masked_logits
if do_sample == True and top_p != None:
# 1. Применим softmax, чтобы получить вероятности:
probs = F.softmax(logits_scaled, dim=-1) # [B, vocab_size]
# 2. Отсортируем токены по убыванию вероятностей:
sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)
# 3. Посчитаем кумулятивную сумму вероятностей:
cum_probs = torch.cumsum(sorted_probs, dim=-1) # [B, vocab_size]
# 4. Определим маску: оставить токены, пока сумма < top_p
sorted_mask = (cum_probs <= top_p).bool() if hasattr(torch, "bool") else (cum_probs <= top_p).byte() # [B, vocab_size]
# Гарантируем, что хотя бы первый токен останется
sorted_mask[:, 0] = True if hasattr(torch, "bool") else 1
# 5. Преобразуем маску обратно в оригинальный порядок:
# Создаём полную маску из 0
mask = torch.zeros_like(probs, dtype=torch.bool if hasattr(torch, "bool") else torch.uint8)
# Устанавливаем 1 в местах нужных токенов
mask.scatter_(dim=1, index=sorted_indices, src=sorted_mask)
# 6. Зануляем логиты токенов вне топ-p:
logits_scaled[~mask] = float('-inf')
# 4. Применяем Softmax
probs = F.softmax(logits_scaled, dim=-1) # [batch_size, vocab_size]
if do_sample == True:
# 5. Если do_sample равен True, то отбираем токен случайно с помощью torch.multinomial
next_token = torch.multinomial(probs, num_samples=1) # [batch_size, 1]
else:
# 5. Если do_sample равен False, то выбираем токен с максимальной вероятностью
next_token = torch.argmax(probs, dim=-1, keepdim=True) # [batch_size, 1]
# 6. Добавляем его к последовательности
x = torch.cat([x, next_token], dim=1) # [batch_size, seq_len+1]
return x
@property
def max_seq_len(self) -> int:
return self._max_seq_len

View File

@@ -0,0 +1,60 @@
import torch
import pytest
from llm.core.geglu import GeGLU
@pytest.fixture
def geglu():
return GeGLU(emb_size=16, dropout=0.1)
def test_forward_shape(geglu):
x = torch.randn(2, 5, 16)
y = geglu(x)
assert y.shape == x.shape
def test_forward_no_batch(geglu):
x = torch.randn(1, 16)
y = geglu(x.unsqueeze(0))
assert y.shape == (1, 1, 16)
@pytest.mark.skip(reason="float16 not supported without parameter casting")
def test_forward_dtype_fp16():
geglu = GeGLU(emb_size=8, dropout=0.0)
x = torch.randn(2, 4, 8).half()
y = geglu(x)
assert y.shape == x.shape
assert y.dtype == torch.float16
def test_forward_no_dropout():
geglu = GeGLU(emb_size=4, dropout=0.0)
x = torch.randn(3, 2, 4)
y = geglu(x)
assert not torch.isnan(y).any()
assert not torch.isinf(y).any()
def test_gradient_flow(geglu):
x = torch.randn(3, 8, 16, requires_grad=True)
y = geglu(x)
y.sum().backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_forward_repeatability():
torch.manual_seed(42)
geglu = GeGLU(emb_size=8, dropout=0.0)
x = torch.randn(3, 2, 8)
y1 = geglu(x)
torch.manual_seed(42)
geglu2 = GeGLU(emb_size=8, dropout=0.0)
x2 = torch.randn(3, 2, 8)
y2 = geglu2(x2)
assert torch.allclose(y1, y2, atol=1e-5)
def test_edge_small_large():
geglu = GeGLU(emb_size=2, dropout=0.0)
x = torch.randn(2, 2, 2)
y = geglu(x)
assert y.shape == x.shape
geglu = GeGLU(emb_size=256, dropout=0.0)
x = torch.randn(1, 1, 256)
y = geglu(x)
assert y.shape == x.shape

View File

@@ -0,0 +1,67 @@
import torch
import pytest
from llm.core.gemma_decoder import GemmaDecoder
from llm.core.rope import RoPE
@pytest.fixture
def gemma_decoder():
rope = RoPE(head_size=4, max_seq_len=32)
return GemmaDecoder(
num_q_heads=4,
emb_size=16,
head_size=4,
max_seq_len=32,
rope=rope,
dropout=0.1,
)
def test_forward_shape(gemma_decoder):
x = torch.randn(2, 12, 16)
out, cache = gemma_decoder(x)
assert out.shape == (2, 12, 16)
assert isinstance(cache, tuple) or cache is None
def test_forward_masked(gemma_decoder):
x = torch.randn(1, 8, 16)
mask = torch.ones(1, 8, 8, dtype=torch.bool)
out, _ = gemma_decoder(x, mask=mask)
assert out.shape == x.shape
def test_forward_with_cache_flag(gemma_decoder):
x = torch.randn(2, 7, 16)
out, cache = gemma_decoder(x, use_cache=True, cache=None)
assert out.shape == (2, 7, 16)
def test_forward_wrong_seq_len_raises(gemma_decoder):
x = torch.randn(1, 100, 16)
with pytest.raises(Exception):
gemma_decoder(x)
def test_gradient_flow(gemma_decoder):
x = torch.randn(3, 9, 16, requires_grad=True)
y, _ = gemma_decoder(x)
y.sum().backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_various_shapes(gemma_decoder):
for b, s in [(1, 1), (2, 5), (2, 32)]:
x = torch.randn(b, s, 16)
y, _ = gemma_decoder(x)
assert y.shape == (b, s, 16)
def test_forward_repeatability():
torch.manual_seed(42)
rope = RoPE(head_size=4, max_seq_len=32)
decoder = GemmaDecoder(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
)
x = torch.randn(2, 8, 16)
y1, _ = decoder(x)
torch.manual_seed(42)
decoder2 = GemmaDecoder(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=rope, dropout=0.0,
)
x2 = torch.randn(2, 8, 16)
y2, _ = decoder2(x2)
assert torch.allclose(y1, y2, atol=1e-5)

View File

@@ -4,17 +4,17 @@ Tests for decoder block.
import pytest
import torch
from llm.core.decoder import Decoder
from llm.core.gpt_decoder import GptDecoder
class TestDecoder:
class TestGptDecoder:
"""Test cases for Decoder."""
def test_initialization(self, embed_dim, num_heads):
"""Test that Decoder can be initialized."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -32,7 +32,7 @@ class TestDecoder:
"""Test forward pass of Decoder."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -40,7 +40,7 @@ class TestDecoder:
)
# Forward pass
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Check output shape
assert output.shape == random_embeddings.shape
@@ -50,7 +50,7 @@ class TestDecoder:
"""Test forward pass with causal mask."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -62,7 +62,7 @@ class TestDecoder:
mask = torch.tril(torch.ones(seq_len, seq_len))
# Forward pass with causal mask
output = decoder(random_embeddings, mask=mask)
output, _ = decoder(random_embeddings, attention_mask=mask)
# Check output shape
assert output.shape == random_embeddings.shape
@@ -71,14 +71,14 @@ class TestDecoder:
"""Test that residual connections are properly applied."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# With residual connections and layer norm, the output shouldn't be
# too different from input (in terms of scale/distribution)
@@ -92,14 +92,14 @@ class TestDecoder:
"""Test that layer normalization is applied."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
max_seq_len=max_seq_len,
)
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Check that output has reasonable statistics (due to layer norm)
# Mean should be close to 0, std close to 1 for each sequence position
@@ -114,7 +114,7 @@ class TestDecoder:
"""Test that gradients flow through Decoder."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -122,7 +122,7 @@ class TestDecoder:
)
# Forward pass
output = decoder(random_embeddings)
output, _ = decoder(random_embeddings)
# Create a dummy loss and backward pass
loss = output.sum()
@@ -139,7 +139,7 @@ class TestDecoder:
"""Test that Decoder works on correct device."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -148,7 +148,7 @@ class TestDecoder:
inputs = random_embeddings.to(device)
# Forward pass
output = decoder(inputs)
output, _ = decoder(inputs)
# Check device consistency
assert output.device == device
@@ -165,7 +165,7 @@ class TestDecoder:
for embed_dim, num_heads in test_cases:
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -174,7 +174,7 @@ class TestDecoder:
batch_size, seq_len = 2, 16
inputs = torch.randn(batch_size, seq_len, embed_dim)
output = decoder(inputs)
output, _ = decoder(inputs)
assert output.shape == inputs.shape
@@ -183,7 +183,7 @@ class TestDecoder:
"""Test Decoder with different input shapes."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -191,7 +191,7 @@ class TestDecoder:
)
inputs = torch.randn(batch_size, seq_len, embed_dim)
output = decoder(inputs)
output, _ = decoder(inputs)
assert output.shape == (batch_size, seq_len, embed_dim)
@@ -199,7 +199,7 @@ class TestDecoder:
"""Test that Decoder behaves differently in train vs eval mode."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,
@@ -209,11 +209,11 @@ class TestDecoder:
# Training mode
decoder.train()
output_train = decoder(random_embeddings)
output_train, _ = decoder(random_embeddings)
# Evaluation mode
decoder.eval()
output_eval = decoder(random_embeddings)
output_eval, _ = decoder(random_embeddings)
# Outputs should be different due to dropout
assert not torch.allclose(output_train, output_eval)
@@ -222,7 +222,7 @@ class TestDecoder:
"""Test that parameters are properly initialized."""
head_size = embed_dim // num_heads
max_seq_len = 1024
decoder = Decoder(
decoder = GptDecoder(
num_heads=num_heads,
emb_size=embed_dim,
head_size=head_size,

View File

@@ -0,0 +1,80 @@
import torch
import pytest
from llm.core.mixtral_decoder import MixtralDecoder
from llm.core.rope import RoPE
@pytest.fixture
def basic_decoder():
emb_size = 16
num_q_heads = 4
num_kv_heads = 2
head_size = 4
max_seq_len = 32
num_experts = 4
top_k_experts = 2
window_size = 8
rope = RoPE(head_size=head_size, max_seq_len=max_seq_len)
return MixtralDecoder(
num_q_heads=num_q_heads,
num_kv_heads=num_kv_heads,
emb_size=emb_size,
head_size=head_size,
max_seq_len=max_seq_len,
num_experts=num_experts,
top_k_experts=top_k_experts,
window_size=window_size,
rope=rope,
dropout=0.0,
)
def test_forward_shape(basic_decoder):
x = torch.randn(2, 10, 16)
out, cache = basic_decoder(x)
assert out.shape == (2, 10, 16)
assert cache is None or isinstance(cache, (tuple, list))
def test_forward_masked(basic_decoder):
x = torch.randn(3, 7, 16)
mask = torch.ones(3, 7, 7, dtype=torch.bool)
out, cache = basic_decoder(x, mask=mask)
assert out.shape == (3, 7, 16)
def test_forward_with_cache_flag(basic_decoder):
x = torch.randn(2, 8, 16)
out, cache = basic_decoder(x, use_cache=True, cache=None)
assert out.shape == (2, 8, 16)
assert isinstance(cache, (tuple, list)) or cache is None
def test_backprop_pass(basic_decoder):
x = torch.randn(2, 5, 16, requires_grad=True)
out, _ = basic_decoder(x)
y = out.sum()
y.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_seq_too_long_raises(basic_decoder):
x = torch.randn(1, 40, 16) # seq_len > max_seq_len
with pytest.raises(Exception):
basic_decoder(x)
def test_different_config():
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=2, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=4, rope=rope, dropout=0.1
)
x = torch.randn(1, 8, 4)
out, cache = decoder(x)
assert out.shape == x.shape
def test_forward_no_dropout():
# Проверка на корректность shape при отсутствии Dropout
rope = RoPE(head_size=2, max_seq_len=12)
decoder = MixtralDecoder(
num_q_heads=2, num_kv_heads=1, emb_size=4, head_size=2,
max_seq_len=12, num_experts=2, top_k_experts=1, window_size=3, rope=rope, dropout=0.0
)
x = torch.randn(2, 3, 4)
out, cache = decoder(x)
assert out.shape == x.shape

View File

@@ -0,0 +1,61 @@
import torch
import pytest
from llm.core.moe import MoE
@pytest.fixture
def moe():
# Базовая MoE для коротких тестов
return MoE(emb_size=16, num_experts=4, top_k_experts=2, dropout=0.0)
def test_forward_shape(moe):
x = torch.randn(3, 5, 16) # [batch, seq, emb]
y = moe(x)
assert y.shape == x.shape
def test_forward_grad(moe):
x = torch.randn(2, 4, 16, requires_grad=True)
y = moe(x)
(y.sum()).backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_top_k_larger_than_experts():
# top_k_experts > num_experts должно падать
with pytest.raises(ValueError):
MoE(emb_size=8, num_experts=2, top_k_experts=4)
def test_single_expert_no_error():
# один эксперт, один топ-к — модель всё ещё валидна
moe = MoE(emb_size=8, num_experts=1, top_k_experts=1)
x = torch.randn(2, 2, 8)
y = moe(x)
assert y.shape == x.shape
def test_forward_trivial_weights():
"""Проверяет, что при одинаковых весах роутера MoE возвращает усреднённое по экспертам."""
class DummyMoE(MoE):
def forward(self, x):
# Роутер отдаёт всегда единичные логиты = softmax -> uniform
self._router = torch.nn.Linear(x.size(-1), self._num_experts, bias=False)
torch.nn.init.constant_(self._router.weight, 0.0)
return super().forward(x)
moe = DummyMoE(emb_size=4, num_experts=2, top_k_experts=2)
x = torch.zeros(1, 2, 4)
y = moe(x)
assert y.shape == x.shape
def test_forward_deterministic_seed(moe):
torch.manual_seed(42)
x = torch.randn(2, 3, 16)
y1 = moe(x)
torch.manual_seed(42)
y2 = moe(x)
assert torch.allclose(y1, y2, atol=1e-5)
def test_forward_no_dropout():
"""Без dropout MoE не меняет shape и не даёт NaN."""
moe = MoE(emb_size=5, num_experts=3, top_k_experts=2, dropout=0.0)
x = torch.randn(2, 7, 5)
y = moe(x)
assert y.shape == x.shape
assert not torch.isnan(y).any()

View File

@@ -0,0 +1,71 @@
import torch
import pytest
from llm.core.multi_query_attention import MultiQueryAttention
from llm.core.rope import RoPE
@pytest.fixture
def mqa_rope():
return MultiQueryAttention(
num_q_heads=4, emb_size=16, head_size=4, max_seq_len=32, rope=RoPE(head_size=4, max_seq_len=32), dropout=0.1
)
@pytest.fixture
def mqa_no_rope():
return MultiQueryAttention(
num_q_heads=2, emb_size=8, head_size=4, max_seq_len=16, rope=None, dropout=0.0
)
def test_forward_shape(mqa_rope):
x = torch.randn(2, 10, 16)
out, cache = mqa_rope(x)
assert out.shape == (2, 10, 16)
assert isinstance(cache, tuple) and len(cache) == 2
def test_forward_masked(mqa_rope):
x = torch.randn(2, 8, 16)
mask = torch.ones(2, 8, 8, dtype=torch.bool)
out, cache = mqa_rope(x, mask=mask)
assert out.shape == (2, 8, 16)
def test_forward_cache(mqa_rope):
x = torch.randn(1, 4, 16)
# Первый вызов — кэша нет
out1, cache1 = mqa_rope(x)
# Повторяем: подаем x второй раз — теперь добавим cache
out2, cache2 = mqa_rope(x, use_cache=True, cache=cache1)
assert out2.shape == (1, 4, 16)
assert isinstance(cache2, tuple) and len(cache2) == 2
# Проверка, что длина k_cache увеличилась
assert cache2[0].shape[2] == cache1[0].shape[2] + x.shape[1] # по длине seq
def test_forward_no_rope(mqa_no_rope):
x = torch.randn(3, 6, 8)
out, _ = mqa_no_rope(x)
assert out.shape == (3, 6, 8)
def test_forward_different_batch_seq(mqa_rope):
for batch, seq in [(1, 1), (2, 5), (3, 32)]:
x = torch.randn(batch, seq, 16)
out, _ = mqa_rope(x)
assert out.shape == (batch, seq, 16)
def test_forward_raise_on_long_seq(mqa_rope):
x = torch.randn(2, 40, 16) # seq_len > max_seq_len
with pytest.raises(ValueError):
mqa_rope(x)
def test_forward_grad(mqa_rope):
x = torch.randn(2, 7, 16, requires_grad=True)
out, _ = mqa_rope(x)
y = out.sum()
y.backward()
assert x.grad is not None
assert x.grad.shape == x.shape
def test_dropout_applied():
mqa = MultiQueryAttention(num_q_heads=2, emb_size=8, head_size=4, max_seq_len=12, rope=None, dropout=0.99)
x = torch.ones(1, 3, 8)
mqa.train()
y, _ = mqa(x)
# При очень большом dropout почти всё обнуляется
assert (torch.abs(y) < 1e-5).float().mean() > 0.6 or y.sum() < 1e-2

View File

@@ -0,0 +1,56 @@
# llm/tests/models/test_gemma.py
import torch
import pytest
from llm.models.gemma.gemma import Gemma
@pytest.fixture
def config():
return {
"vocab_size": 100,
"embed_dim": 32,
"num_q_heads": 4,
"num_layers": 2,
"max_position_embeddings": 16,
"dropout": 0.0,
}
@pytest.fixture
def model(config):
return Gemma(config)
def test_forward_basic(model):
x = torch.randint(0, 100, (2, 8))
logits, cache = model(x)
assert logits.shape == (2, 8, 100)
assert isinstance(cache, list)
assert len(cache) == model._decoders.__len__()
def test_forward_with_cache(model):
x = torch.randint(0, 100, (2, 4))
logits, cache = model(x, use_cache=True)
# Второй проход с cache и одним новым токеном
x2 = torch.randint(0, 100, (2, 1))
logits2, cache2 = model(x2, use_cache=True, cache=cache)
assert logits2.shape == (2, 1, 100)
assert isinstance(cache2, list)
def test_generate_and_shape(model):
x = torch.randint(0, 100, (1, 5))
result = model.generate(x, max_new_tokens=3, do_sample=False)
assert result.shape == (1, 8)
def test_forward_sequence_too_long(model, config):
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
with pytest.raises(ValueError):
model(x)
def test_generate_with_sampling_topk(model):
x = torch.randint(0, 100, (1, 3))
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
assert out.shape == (1, 5)
def test_generate_with_sampling_topp(model):
x = torch.randint(0, 100, (1, 3))
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
assert out.shape == (1, 5)

View File

@@ -30,7 +30,7 @@ class TestGPT:
model = GPT(gpt_config)
# Forward pass
logits = model(random_inputs)
logits, _ = model(random_inputs)
# Check output shape
batch_size, seq_len = random_inputs.shape
@@ -45,7 +45,7 @@ class TestGPT:
model = GPT(gpt_config)
# Forward pass with mask
logits = model(random_inputs, attention_mask=attention_mask)
logits, _ = model(random_inputs, attention_mask=attention_mask)
# Check output shape
batch_size, seq_len = random_inputs.shape
@@ -132,7 +132,7 @@ class TestGPT:
model = GPT(gpt_config)
# Forward pass
logits = model(random_inputs)
logits, _ = model(random_inputs)
# Create a dummy loss and backward pass
targets = torch.randint(0, gpt_config["vocab_size"], random_inputs.shape)
@@ -157,7 +157,7 @@ class TestGPT:
inputs = random_inputs.to(device)
# Forward pass
logits = model(inputs)
logits, _ = model(inputs)
# Check device consistency
assert logits.device == device
@@ -197,7 +197,7 @@ class TestGPT:
batch_size, seq_len = 2, 16
inputs = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
logits = model(inputs)
logits, _ = model(inputs)
expected_shape = (batch_size, seq_len, config["vocab_size"])
assert logits.shape == expected_shape
@@ -208,7 +208,7 @@ class TestGPT:
model = GPT(gpt_config)
inputs = torch.randint(0, gpt_config["vocab_size"], (batch_size, seq_len))
logits = model(inputs)
logits, _ = model(inputs)
expected_shape = (batch_size, seq_len, gpt_config["vocab_size"])
assert logits.shape == expected_shape
@@ -219,11 +219,11 @@ class TestGPT:
# Training mode
model.train()
output_train = model(random_inputs)
output_train, _ = model(random_inputs)
# Evaluation mode
model.eval()
output_eval = model(random_inputs)
output_eval, _ = model(random_inputs)
# Outputs should be different due to dropout
assert not torch.allclose(output_train, output_eval)
@@ -271,7 +271,7 @@ class TestGPT:
"""Test that GPT output has proper distribution."""
model = GPT(gpt_config)
logits = model(random_inputs)
logits, _ = model(random_inputs)
# Logits should not have extreme values
assert logits.abs().max() < 100

View File

@@ -0,0 +1,57 @@
import torch
import pytest
from llm.models.mixtral.mixtral import Mixtral
@pytest.fixture
def config():
return {
"vocab_size": 100,
"embed_dim": 32,
"num_q_heads": 4,
"num_kv_heads": 2,
"num_layers": 2,
"max_position_embeddings": 16,
"window_size": 8,
"dropout": 0.0,
"num_experts": 4,
"top_k_experts": 2,
}
@pytest.fixture
def model(config):
return Mixtral(config)
def test_forward_basic(model):
x = torch.randint(0, 100, (2, 8))
logits, cache = model(x)
assert logits.shape == (2, 8, 100)
assert isinstance(cache, list)
assert len(cache) == model._decoders.__len__()
def test_forward_with_cache(model):
x = torch.randint(0, 100, (2, 4))
logits, cache = model(x, use_cache=True)
x2 = torch.randint(0, 100, (2, 1))
logits2, cache2 = model(x2, use_cache=True, cache=cache)
assert logits2.shape == (2, 1, 100)
assert isinstance(cache2, list)
def test_generate_and_shape(model):
x = torch.randint(0, 100, (1, 5))
result = model.generate(x, max_new_tokens=3, do_sample=False)
assert result.shape == (1, 8)
def test_forward_sequence_too_long(model, config):
x = torch.randint(0, 100, (1, config["max_position_embeddings"] + 1))
with pytest.raises(ValueError):
model(x)
def test_generate_with_sampling_topk(model):
x = torch.randint(0, 100, (1, 3))
out = model.generate(x, max_new_tokens=2, do_sample=True, top_k=5)
assert out.shape == (1, 5)
def test_generate_with_sampling_topp(model):
x = torch.randint(0, 100, (1, 3))
out = model.generate(x, max_new_tokens=2, do_sample=True, top_p=0.8)
assert out.shape == (1, 5)

View File

@@ -28,7 +28,7 @@ def test_gpt_model_creation():
input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len))
with torch.no_grad():
logits = model(input_ids)
logits, _ = model(input_ids)
assert logits.shape == (batch_size, seq_len, config["vocab_size"])
print("✅ GPT model creation and forward pass test passed")
@@ -222,7 +222,7 @@ def test_gpt_with_tokenizer():
input_ids = torch.tensor([tokens])
with torch.no_grad():
logits = model(input_ids)
logits, _ = model(input_ids)
assert logits.shape == (1, len(tokens), vocab_size)
print("✅ GPT with tokenizer integration test passed")

1344
notebooks/gemma.ipynb Normal file

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

1510
notebooks/mixstral.ipynb Normal file

File diff suppressed because it is too large Load Diff