refactor(models): unify generate() signatures across all LLM architectures\n\n- Unified method signature: (x, max_new_tokens, do_sample, temperature, top_k, top_p, use_cache, attention_mask, **kwargs)\n- Added del attention_mask, kwargs in every generate() for compatibility and clean API\n- Prepared for drop-in replacement and ease of future batching/serving\n\nNo changes to core model logic or sampling algorithms.

This commit is contained in:
Sergey Penkovsky
2025-10-22 11:57:26 +03:00
parent 92a34551b8
commit ddc4924a37
6 changed files with 28 additions and 14 deletions

View File

@@ -209,14 +209,17 @@ class Gemma(BaseModel):
else: else:
return (logits, None) return (logits, None)
def generate(self, def generate(
x: torch.Tensor, self,
max_new_tokens: int, x: torch.Tensor,
max_new_tokens: int,
do_sample: bool, do_sample: bool,
temperature: float = 1.0, temperature: float = 1.0,
top_k: int = None, top_k: int = None,
top_p: float = None, top_p: float = None,
use_cache: bool = True use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Авторегрессивная генерация токенов с использованием greedy, temperature, top-k и top-p sampling. Авторегрессивная генерация токенов с использованием greedy, temperature, top-k и top-p sampling.

View File

@@ -193,8 +193,9 @@ class GPT(BaseModel):
temperature: float = 1.0, temperature: float = 1.0,
top_k: int = None, top_k: int = None,
top_p: float = None, top_p: float = None,
attention_mask: torch.Tensor = None, # Добавляем для совместимости с HF use_cache: bool = True,
**kwargs, # Игнорируем остальные параметры attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Авторегрессивная генерация текста с поддержкой жадного поиска (greedy), вероятностного сэмплирования с температурой, Авторегрессивная генерация текста с поддержкой жадного поиска (greedy), вероятностного сэмплирования с температурой,

View File

@@ -214,6 +214,8 @@ class GPT2(BaseModel):
top_k: int = None, top_k: int = None,
top_p: float = None, top_p: float = None,
use_cache: bool = True, use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k, top-p sampling и KV-кэша. Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k, top-p sampling и KV-кэша.

View File

@@ -176,6 +176,8 @@ class Llama(BaseModel):
top_k: int = None, top_k: int = None,
top_p: float = None, top_p: float = None,
use_cache: bool = True, use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша). Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).

View File

@@ -140,14 +140,17 @@ class Mistral(BaseModel):
else: else:
return (logits, None) return (logits, None)
def generate(self, def generate(
x: torch.Tensor, self,
max_new_tokens: int, x: torch.Tensor,
max_new_tokens: int,
do_sample: bool, do_sample: bool,
temperature: float = 1.0, temperature: float = 1.0,
top_k: int = None, top_k: int = None,
top_p: float = None, top_p: float = None,
use_cache: bool = True use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling

View File

@@ -222,14 +222,17 @@ class Mixtral(BaseModel):
else: else:
return (logits, None) return (logits, None)
def generate(self, def generate(
x: torch.Tensor, self,
max_new_tokens: int, x: torch.Tensor,
max_new_tokens: int,
do_sample: bool, do_sample: bool,
temperature: float = 1.0, temperature: float = 1.0,
top_k: int = None, top_k: int = None,
top_p: float = None, top_p: float = None,
use_cache: bool = True use_cache: bool = True,
attention_mask: torch.Tensor = None,
**kwargs
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling