mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
- Add Mixtral architecture implementation with MoE support (llm/src/llm/models/mixtral/mixtral.py) - Introduce generic Mixture-of-Experts (MoE) block (llm/src/llm/core/moe.py) - Create dedicated configuration files for Mixtral training and generation experiments - Register and test Mixtral support in experiment runner (run_llm_experiment.py) - Add unit tests for Mixtral API including forward, caching, and generation modes - Include Jupyter notebook mixstral.ipynb for architectural exploration and research - Ensure correct handling of torch bool masks in sampling (top-k, top-p) during generation BREAKING CHANGE: Adds new model code and test coverage, modifying experiment runner logic to register Mixtral.
58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
import torch
|
|
import pytest
|
|
from llm.models.mixtral.mixtral import Mixtral
|
|
|
|
@pytest.fixture
def config():
    """Minimal Mixtral hyperparameter set, kept tiny so the tests run fast.

    Two layers, a 32-dim embedding, and a 100-token vocabulary are enough
    to exercise forward, caching, and generation paths.
    """
    return dict(
        vocab_size=100,
        embed_dim=32,
        num_q_heads=4,
        num_kv_heads=2,
        num_layers=2,
        max_position_embeddings=16,
        window_size=8,
        dropout=0.0,
        num_experts=4,
        top_k_experts=2,
    )
|
|
|
|
@pytest.fixture
def model(config):
    """Build a small Mixtral instance from the shared ``config`` fixture."""
    mixtral = Mixtral(config)
    return mixtral
|
|
|
|
def test_forward_basic(model):
    """A plain forward pass yields (batch, seq, vocab) logits and a
    per-layer cache list.
    """
    x = torch.randint(0, 100, (2, 8))
    logits, cache = model(x)
    assert logits.shape == (2, 8, 100)
    assert isinstance(cache, list)
    # Idiomatic len() instead of calling __len__ directly; one cache
    # entry is expected per decoder layer.
    assert len(cache) == len(model._decoders)
|
|
|
|
def test_forward_with_cache(model):
    """Incremental decoding: feeding one new token with the cache from a
    prefill pass yields logits for just that token.
    """
    x = torch.randint(0, 100, (2, 4))
    logits, cache = model(x, use_cache=True)
    # Previously unchecked: the prefill pass should already produce
    # full-sequence logits.
    assert logits.shape == (2, 4, 100)
    x2 = torch.randint(0, 100, (2, 1))
    logits2, cache2 = model(x2, use_cache=True, cache=cache)
    assert logits2.shape == (2, 1, 100)
    assert isinstance(cache2, list)
|
|
|
|
def test_generate_and_shape(model):
    """Greedy generation appends exactly ``max_new_tokens`` tokens."""
    prompt_len, new_tokens = 5, 3
    prompt = torch.randint(0, 100, (1, prompt_len))
    result = model.generate(prompt, max_new_tokens=new_tokens, do_sample=False)
    assert result.shape == (1, prompt_len + new_tokens)
|
|
|
|
def test_forward_sequence_too_long(model, config):
    """A sequence longer than ``max_position_embeddings`` must raise."""
    too_long = config["max_position_embeddings"] + 1
    tokens = torch.randint(0, 100, (1, too_long))
    with pytest.raises(ValueError):
        model(tokens)
|
|
|
|
def test_generate_with_sampling_topk(model):
    """Top-k sampling runs end to end and keeps the expected length."""
    prompt = torch.randint(0, 100, (1, 3))
    sampled = model.generate(prompt, max_new_tokens=2, do_sample=True, top_k=5)
    assert sampled.shape == (1, 3 + 2)
|
|
|
|
def test_generate_with_sampling_topp(model):
    """Nucleus (top-p) sampling runs end to end and keeps the expected length."""
    prompt = torch.randint(0, 100, (1, 3))
    sampled = model.generate(prompt, max_new_tokens=2, do_sample=True, top_p=0.8)
    assert sampled.shape == (1, 3 + 2)
|