Mirror of https://github.com/pese-git/llm-arch-research.git (synced 2026-01-23 13:00:54 +00:00)
refactor(models): unify generate() signatures across all LLM architectures

- Unified method signature: (x, max_new_tokens, do_sample, temperature, top_k, top_p, use_cache, attention_mask, **kwargs)
- Added del attention_mask, kwargs in every generate() for compatibility and a clean API
- Prepared for drop-in replacement and ease of future batching/serving

No changes to core model logic or sampling algorithms.
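For reference, a minimal sketch of the unified signature that every architecture now exposes, assembled from the hunks below. The body is illustrative only (this commit does not touch the sampling logic); the del attention_mask, kwargs line mirrors the pattern named in the commit message, and BaseModel is the repository's own base class.

    import torch

    class Gemma(BaseModel):  # same signature in GPT, GPT2, Llama, Mistral, Mixtral
        def generate(
            self,
            x: torch.Tensor,
            max_new_tokens: int,
            do_sample: bool,
            temperature: float = 1.0,
            top_k: int = None,
            top_p: float = None,
            use_cache: bool = True,
            attention_mask: torch.Tensor = None,
            **kwargs,
        ) -> torch.Tensor:
            # Accepted for HF-style call sites, intentionally unused here.
            del attention_mask, kwargs
            ...  # greedy / temperature / top-k / top-p decoding loop (unchanged)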
@@ -209,14 +209,17 @@ class Gemma(BaseModel):
         else:
             return (logits, None)

-    def generate(self,
-            x: torch.Tensor,
-            max_new_tokens: int,
+    def generate(
+        self,
+        x: torch.Tensor,
+        max_new_tokens: int,
+        do_sample: bool,
         temperature: float = 1.0,
         top_k: int = None,
         top_p: float = None,
-        use_cache: bool = True
+        use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive token generation using greedy, temperature, top-k, and top-p sampling.
@@ -193,8 +193,9 @@ class GPT(BaseModel):
         temperature: float = 1.0,
         top_k: int = None,
         top_p: float = None,
-        attention_mask: torch.Tensor = None,  # added for HF compatibility
-        **kwargs,  # ignore the remaining parameters
+        use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive text generation with support for greedy search, probabilistic sampling with temperature,
@@ -214,6 +214,8 @@ class GPT2(BaseModel):
         top_k: int = None,
         top_p: float = None,
         use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive token generation with support for greedy, temperature, top-k, top-p sampling, and a KV cache.
@@ -176,6 +176,8 @@ class Llama(BaseModel):
         top_k: int = None,
         top_p: float = None,
         use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive sequence generation for LLaMA (greedy, temperature, top-k, top-p/nucleus, with KV-cache support).
@@ -140,14 +140,17 @@ class Mistral(BaseModel):
         else:
             return (logits, None)

-    def generate(self,
-            x: torch.Tensor,
-            max_new_tokens: int,
+    def generate(
+        self,
+        x: torch.Tensor,
+        max_new_tokens: int,
+        do_sample: bool,
         temperature: float = 1.0,
         top_k: int = None,
         top_p: float = None,
-        use_cache: bool = True
+        use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive token generation with support for greedy, temperature, and top-k/top-p sampling
@@ -222,14 +222,17 @@ class Mixtral(BaseModel):
         else:
             return (logits, None)

-    def generate(self,
-            x: torch.Tensor,
-            max_new_tokens: int,
+    def generate(
+        self,
+        x: torch.Tensor,
+        max_new_tokens: int,
+        do_sample: bool,
         temperature: float = 1.0,
         top_k: int = None,
         top_p: float = None,
-        use_cache: bool = True
+        use_cache: bool = True,
+        attention_mask: torch.Tensor = None,
+        **kwargs
     ) -> torch.Tensor:
         """
         Autoregressive token generation with support for greedy, temperature, and top-k/top-p sampling
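With the signatures aligned, generate() becomes a drop-in call across architectures. A hypothetical usage sketch, assuming a constructed model and tokenized input_ids (neither shown in this diff):

    # Hypothetical: any refactored model (Gemma, GPT, GPT2, Llama, Mistral, Mixtral)
    # can be driven through the same call.
    output_ids = model.generate(
        x=input_ids,          # torch.Tensor of token ids
        max_new_tokens=50,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        top_p=0.95,
        use_cache=True,
        attention_mask=None,  # accepted and ignored, for HF-style compatibility
    )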