fix(hf-integration): handle logits as tuple in hf_adapter, convert torch.Tensor to list in hf_tokenizer.decode for decoding compatibility

commit c31eed8551
parent 3843e64098
Author: Sergey Penkovsky
Date:   2025-10-05 20:47:36 +03:00

2 changed files with 23 additions and 5 deletions
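The second half of the subject line (converting torch.Tensor to a list in hf_tokenizer.decode) is not visible in the hunks excerpted below. A minimal sketch of what that change plausibly looks like, assuming a decode method that delegates to the wrapped tokenizer; the method body here is an illustration, not taken from the diff:

    import torch

    class HFTokenizerAdapter:
        # Hypothetical sketch: generate() hands decode() a torch.Tensor,
        # which is flattened to a plain Python list before the wrapped
        # tokenizer decodes it ("decoding compatibility" in the subject line).
        def decode(self, token_ids, **kwargs):
            if isinstance(token_ids, torch.Tensor):
                token_ids = token_ids.tolist()  # Tensor -> Python list
            return self.llm_tokenizer.decode(token_ids, **kwargs)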


@@ -99,7 +99,11 @@ class HFGPTAdapter(PreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         # Main forward pass
-        logits = self.llm_model(input_ids)
+        outputs = self.llm_model(input_ids)
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        else:
+            logits = outputs
         loss = None
         if labels is not None:
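Why the tuple branch: a plain PyTorch language model may return raw logits or a tuple whose first element is the logits. A standalone sketch of the same normalization, using a toy model as an assumption for illustration (the vocab size is arbitrary):

    import torch

    # Toy stand-in for self.llm_model that returns a (logits, extras) tuple,
    # as some plain-PyTorch LMs do.
    class ToyLM(torch.nn.Module):
        def forward(self, input_ids):
            batch, seq = input_ids.shape
            return torch.randn(batch, seq, 50257), None  # tuple output

    model = ToyLM()
    outputs = model(torch.zeros(1, 4, dtype=torch.long))
    # Same unwrapping logic as the hunk above.
    logits = outputs[0] if isinstance(outputs, tuple) else outputs
    print(logits.shape)  # torch.Size([1, 4, 50257])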


@@ -56,10 +56,24 @@ class HFTokenizerAdapter:
         add_special_tokens = kwargs.get('add_special_tokens', True)
         # Encode the text
-        input_ids = self.llm_tokenizer.encode(
-            text,
-            add_special_tokens=add_special_tokens
-        )
+        #input_ids = self.llm_tokenizer.encode(
+        #    text,
+        #    add_special_tokens=add_special_tokens
+        #)
+        if isinstance(text, str):
+            input_ids = self.llm_tokenizer.encode(
+                text,
+                add_special_tokens=add_special_tokens
+            )
+            input_ids = [input_ids]  # <-- wrap into a batch
+        else:
+            # List of strings: batch mode!
+            input_ids = [
+                self.llm_tokenizer.encode(
+                    t,
+                    add_special_tokens=add_special_tokens
+                ) for t in text
+            ]
         # Apply truncation
         if truncation and max_length is not None and len(input_ids) > max_length:
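With this change, encode always produces a batch (a list of token-id lists): a single string is wrapped in a one-element batch, and a list of strings is encoded per element. A hedged standalone approximation of that path against a plain Hugging Face tokenizer (the adapter internals and the gpt2 checkpoint are assumptions for illustration):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")

    # Standalone approximation of the adapter's new encode branching.
    def encode_batched(text, add_special_tokens=True):
        if isinstance(text, str):
            # single string -> one-element batch
            return [tok.encode(text, add_special_tokens=add_special_tokens)]
        # list of strings -> one id list per input string
        return [tok.encode(t, add_special_tokens=add_special_tokens) for t in text]

    print(encode_batched("hello world"))       # [[id, id, ...]]
    print(encode_batched(["hello", "world"]))  # [[id, ...], [id, ...]]

Note that after this change input_ids is a list of lists, so the len(input_ids) in the trailing truncation check counts sequences in the batch rather than tokens in a sequence.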