mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
refactor(models): unify generate() signatures across all LLM architectures\n\n- Unified method signature: (x, max_new_tokens, do_sample, temperature, top_k, top_p, use_cache, attention_mask, **kwargs)\n- Added del attention_mask, kwargs in every generate() for compatibility and clean API\n- Prepared for drop-in replacement and ease of future batching/serving\n\nNo changes to core model logic or sampling algorithms.
This commit is contained in:
@@ -209,14 +209,17 @@ class Gemma(BaseModel):
|
|||||||
else:
|
else:
|
||||||
return (logits, None)
|
return (logits, None)
|
||||||
|
|
||||||
def generate(self,
|
def generate(
|
||||||
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
max_new_tokens: int,
|
max_new_tokens: int,
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_k: int = None,
|
top_k: int = None,
|
||||||
top_p: float = None,
|
top_p: float = None,
|
||||||
use_cache: bool = True
|
use_cache: bool = True,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
|
**kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Авторегрессивная генерация токенов с использованием greedy, temperature, top-k и top-p sampling.
|
Авторегрессивная генерация токенов с использованием greedy, temperature, top-k и top-p sampling.
|
||||||
|
|||||||
@@ -193,8 +193,9 @@ class GPT(BaseModel):
|
|||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_k: int = None,
|
top_k: int = None,
|
||||||
top_p: float = None,
|
top_p: float = None,
|
||||||
attention_mask: torch.Tensor = None, # Добавляем для совместимости с HF
|
use_cache: bool = True,
|
||||||
**kwargs, # Игнорируем остальные параметры
|
attention_mask: torch.Tensor = None,
|
||||||
|
**kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Авторегрессивная генерация текста с поддержкой жадного поиска (greedy), вероятностного сэмплирования с температурой,
|
Авторегрессивная генерация текста с поддержкой жадного поиска (greedy), вероятностного сэмплирования с температурой,
|
||||||
|
|||||||
@@ -214,6 +214,8 @@ class GPT2(BaseModel):
|
|||||||
top_k: int = None,
|
top_k: int = None,
|
||||||
top_p: float = None,
|
top_p: float = None,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
|
**kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k, top-p sampling и KV-кэша.
|
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k, top-p sampling и KV-кэша.
|
||||||
|
|||||||
@@ -176,6 +176,8 @@ class Llama(BaseModel):
|
|||||||
top_k: int = None,
|
top_k: int = None,
|
||||||
top_p: float = None,
|
top_p: float = None,
|
||||||
use_cache: bool = True,
|
use_cache: bool = True,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
|
**kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).
|
Авторегрессивная генерация последовательностей на основе LLaMA (greedy, temperature, top-k, top-p/nucleus, поддержка KV-кэша).
|
||||||
|
|||||||
@@ -140,14 +140,17 @@ class Mistral(BaseModel):
|
|||||||
else:
|
else:
|
||||||
return (logits, None)
|
return (logits, None)
|
||||||
|
|
||||||
def generate(self,
|
def generate(
|
||||||
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
max_new_tokens: int,
|
max_new_tokens: int,
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_k: int = None,
|
top_k: int = None,
|
||||||
top_p: float = None,
|
top_p: float = None,
|
||||||
use_cache: bool = True
|
use_cache: bool = True,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
|
**kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
||||||
|
|||||||
@@ -222,14 +222,17 @@ class Mixtral(BaseModel):
|
|||||||
else:
|
else:
|
||||||
return (logits, None)
|
return (logits, None)
|
||||||
|
|
||||||
def generate(self,
|
def generate(
|
||||||
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
max_new_tokens: int,
|
max_new_tokens: int,
|
||||||
do_sample: bool,
|
do_sample: bool,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
top_k: int = None,
|
top_k: int = None,
|
||||||
top_p: float = None,
|
top_p: float = None,
|
||||||
use_cache: bool = True
|
use_cache: bool = True,
|
||||||
|
attention_mask: torch.Tensor = None,
|
||||||
|
**kwargs
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
Авторегрессивная генерация токенов с поддержкой greedy, temperature, top-k/top-p sampling
|
||||||
|
|||||||
Reference in New Issue
Block a user