Mirror of https://github.com/pese-git/llm-arch-research.git, synced 2026-01-23 21:10:54 +00:00
fix(hf-integration): handle logits as tuple in hf_adapter, convert torch.Tensor to list in hf_tokenizer.decode for decoding compatibility
@@ -99,7 +99,11 @@ class HFGPTAdapter(PreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # Main forward pass
-        logits = self.llm_model(input_ids)
+        outputs = self.llm_model(input_ids)
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        else:
+            logits = outputs
 
         loss = None
         if labels is not None:
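This guard makes the adapter agnostic to whether the wrapped model's forward returns a bare logits tensor or a tuple such as (logits, cache). A minimal standalone sketch of the same pattern, with ToyModel standing in for self.llm_model (the toy class and the vocabulary size are illustrative, not from the repo):

import torch

class ToyModel:
    # Stand-in for self.llm_model: returns (logits, cache) as a tuple.
    def __call__(self, input_ids):
        batch, seq = input_ids.shape
        return torch.randn(batch, seq, 50257), None

model = ToyModel()
input_ids = torch.zeros(1, 4, dtype=torch.long)

outputs = model(input_ids)
# Unpack uniformly: take element 0 of a tuple, or the tensor itself.
logits = outputs[0] if isinstance(outputs, tuple) else outputs
print(logits.shape)  # torch.Size([1, 4, 50257])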
@@ -56,10 +56,24 @@ class HFTokenizerAdapter:
         add_special_tokens = kwargs.get('add_special_tokens', True)
 
         # Encode the text
-        input_ids = self.llm_tokenizer.encode(
-            text,
-            add_special_tokens=add_special_tokens
-        )
+        #input_ids = self.llm_tokenizer.encode(
+        #    text,
+        #    add_special_tokens=add_special_tokens
+        #)
+        if isinstance(text, str):
+            input_ids = self.llm_tokenizer.encode(
+                text,
+                add_special_tokens=add_special_tokens
+            )
+            input_ids = [input_ids]  # <-- wrap the single sequence in a batch
+        else:
+            # List of strings: batch mode!
+            input_ids = [
+                self.llm_tokenizer.encode(
+                    t,
+                    add_special_tokens=add_special_tokens
+                ) for t in text
+            ]
 
         # Apply truncation
         if truncation and max_length is not None and len(input_ids) > max_length:
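After this change input_ids is always a list of per-sequence id lists, so the unchanged truncation check above compares len(input_ids), which is now the batch size, against max_length rather than any sequence length. A sketch of the same str/list dispatch with truncation applied per sequence instead; the helper name encode_batch and its signature are illustrative, and the per-sequence truncation is a suggested follow-up, not part of this commit:

from typing import List, Optional, Union

def encode_batch(tokenizer, text: Union[str, List[str]],
                 add_special_tokens: bool = True,
                 truncation: bool = False,
                 max_length: Optional[int] = None) -> List[List[int]]:
    # Normalize input: a single string becomes a one-element batch.
    texts = [text] if isinstance(text, str) else text
    input_ids = [
        tokenizer.encode(t, add_special_tokens=add_special_tokens)
        for t in texts
    ]
    if truncation and max_length is not None:
        # Truncate each sequence, not the batch dimension.
        input_ids = [ids[:max_length] for ids in input_ids]
    return input_ids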
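The decode-side fix named in the commit message (converting torch.Tensor to list in hf_tokenizer.decode) is not shown in this excerpt. A minimal sketch of the idea, assuming decode may receive either a torch.Tensor of ids or a plain Python list; the helper name decode_ids is illustrative, not the repo's code:

import torch

def decode_ids(tokenizer, token_ids):
    # Underlying tokenizers often expect plain lists, not tensors.
    if isinstance(token_ids, torch.Tensor):
        token_ids = token_ids.tolist()
    return tokenizer.decode(token_ids)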