From ddc4924a37d14dc12302bcfcc591116ff1c15729 Mon Sep 17 00:00:00 2001
From: Sergey Penkovsky
Date: Wed, 22 Oct 2025 11:57:26 +0300
Subject: [PATCH] refactor(models): unify generate() signatures across all LLM
 architectures

- Unified method signature: (x, max_new_tokens, do_sample, temperature,
  top_k, top_p, use_cache, attention_mask, **kwargs)
- Added del attention_mask, kwargs in every generate() for compatibility
  and a clean API
- Prepared for drop-in replacement and ease of future batching/serving

No changes to core model logic or sampling algorithms.
---
 llm/src/llm/models/gemma/gemma.py     | 11 +++++++----
 llm/src/llm/models/gpt/gpt.py         |  5 +++--
 llm/src/llm/models/gpt/gpt2.py        |  2 ++
 llm/src/llm/models/llama/llama.py     |  2 ++
 llm/src/llm/models/mistral/mistral.py | 11 +++++++----
 llm/src/llm/models/mixtral/mixtral.py | 11 +++++++----
 6 files changed, 28 insertions(+), 14 deletions(-)

diff --git a/llm/src/llm/models/gemma/gemma.py b/llm/src/llm/models/gemma/gemma.py
index dc41bd7..94065d1 100644
--- a/llm/src/llm/models/gemma/gemma.py
+++ b/llm/src/llm/models/gemma/gemma.py
@@ -209,14 +209,17 @@ class Gemma(BaseModel):
         else:
             return (logits, None)
 
-    def generate(self,
-                 x: torch.Tensor,
-                 max_new_tokens: int,
+    def generate(
+        self,
+        x: torch.Tensor,
+        max_new_tokens: int,
         do_sample: bool,
         temperature: float = 1.0,
         top_k: int = None,
         top_p: float = None,
-        use_cache: bool = True
+        use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive token generation using greedy, temperature, top-k and top-p sampling.
diff --git a/llm/src/llm/models/gpt/gpt.py b/llm/src/llm/models/gpt/gpt.py
index 69394f6..d3d5dfc 100644
--- a/llm/src/llm/models/gpt/gpt.py
+++ b/llm/src/llm/models/gpt/gpt.py
@@ -193,8 +193,9 @@ class GPT(BaseModel):
         temperature: float = 1.0,
         top_k: int = None,
         top_p: float = None,
-        attention_mask: torch.Tensor = None,  # Added for HF compatibility
-        **kwargs,  # Ignore the remaining parameters
+        use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive text generation with support for greedy search, probabilistic sampling with temperature,
diff --git a/llm/src/llm/models/gpt/gpt2.py b/llm/src/llm/models/gpt/gpt2.py
index 50c0f9a..2b173e3 100644
--- a/llm/src/llm/models/gpt/gpt2.py
+++ b/llm/src/llm/models/gpt/gpt2.py
@@ -214,6 +214,8 @@ class GPT2(BaseModel):
         top_k: int = None,
         top_p: float = None,
         use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive token generation with support for greedy, temperature, top-k, top-p sampling and a KV cache.
diff --git a/llm/src/llm/models/llama/llama.py b/llm/src/llm/models/llama/llama.py
index 1b98f45..dc7c53d 100644
--- a/llm/src/llm/models/llama/llama.py
+++ b/llm/src/llm/models/llama/llama.py
@@ -176,6 +176,8 @@ class Llama(BaseModel):
         top_k: int = None,
         top_p: float = None,
         use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive sequence generation based on LLaMA (greedy, temperature, top-k, top-p/nucleus, KV-cache support).
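For reference, a minimal runnable sketch of the unified signature that the hunks above and below converge on. TinyLM and its layers are illustrative stand-ins (the real models subclass BaseModel in llm/src/llm/models), and the loop body compresses the per-model sampling code, with top-p filtering and the real KV cache omitted; only the parameter list and the del attention_mask, kwargs line come from the patch itself.

    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        """Toy stand-in for the BaseModel subclasses touched by this patch."""

        def __init__(self, vocab_size: int = 256, dim: int = 32):
            super().__init__()
            self.emb = nn.Embedding(vocab_size, dim)
            self.head = nn.Linear(dim, vocab_size)

        def forward(self, x: torch.Tensor, use_cache: bool = True):
            # Real models return (logits, kv_cache); the cache is elided here.
            return self.head(self.emb(x)), None

        @torch.no_grad()
        def generate(
            self,
            x: torch.Tensor,
            max_new_tokens: int,
            do_sample: bool,
            temperature: float = 1.0,
            top_k: int = None,
            top_p: float = None,
            use_cache: bool = True,
            attention_mask: torch.Tensor = None,
            **kwargs
        ) -> torch.Tensor:
            # Accepted only so HF-style callers can pass them; dropped right
            # away, as the commit message describes (del attention_mask, kwargs).
            del attention_mask, kwargs
            for _ in range(max_new_tokens):
                logits, _ = self(x, use_cache=use_cache)
                logits = logits[:, -1, :] / temperature
                if do_sample:
                    if top_k is not None:
                        # Mask everything below the k-th largest logit.
                        kth = torch.topk(logits, top_k).values[..., -1, None]
                        logits = logits.masked_fill(logits < kth, float("-inf"))
                    # (top_p/nucleus filtering omitted for brevity)
                    next_token = torch.multinomial(torch.softmax(logits, dim=-1), 1)
                else:
                    next_token = logits.argmax(dim=-1, keepdim=True)
                x = torch.cat([x, next_token], dim=1)
            return x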
diff --git a/llm/src/llm/models/mistral/mistral.py b/llm/src/llm/models/mistral/mistral.py
index 1e56eea..3547292 100644
--- a/llm/src/llm/models/mistral/mistral.py
+++ b/llm/src/llm/models/mistral/mistral.py
@@ -140,14 +140,17 @@ class Mistral(BaseModel):
         else:
             return (logits, None)
 
-    def generate(self,
-                 x: torch.Tensor,
-                 max_new_tokens: int,
+    def generate(
+        self,
+        x: torch.Tensor,
+        max_new_tokens: int,
         do_sample: bool,
         temperature: float = 1.0,
         top_k: int = None,
         top_p: float = None,
-        use_cache: bool = True
+        use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive token generation with support for greedy, temperature, top-k/top-p sampling
diff --git a/llm/src/llm/models/mixtral/mixtral.py b/llm/src/llm/models/mixtral/mixtral.py
index a5c6133..1d8e1c9 100644
--- a/llm/src/llm/models/mixtral/mixtral.py
+++ b/llm/src/llm/models/mixtral/mixtral.py
@@ -222,14 +222,17 @@ class Mixtral(BaseModel):
         else:
             return (logits, None)
 
-    def generate(self,
-                 x: torch.Tensor,
-                 max_new_tokens: int,
+    def generate(
+        self,
+        x: torch.Tensor,
+        max_new_tokens: int,
         do_sample: bool,
         temperature: float = 1.0,
         top_k: int = None,
         top_p: float = None,
-        use_cache: bool = True
+        use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive token generation with support for greedy, temperature, top-k/top-p sampling
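With the signatures unified, a caller can swap architectures without touching the call site, which is the drop-in replacement the commit message targets. A hypothetical usage sketch follows; model, input_ids, and mask are assumed to exist in the caller's scope, and only the call shape comes from the patch.

    # Any of Gemma, GPT, GPT2, Llama, Mistral, Mixtral accepts this exact call;
    # attention_mask and extra kwargs are taken for HF compatibility and ignored.
    output_ids = model.generate(
        input_ids,            # (batch, seq_len) prompt tokens
        max_new_tokens=64,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        use_cache=True,
        attention_mask=mask,  # accepted, then dropped via del inside generate()
    )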