mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-24 13:32:08 +00:00
fix(hf-integration): handle logits as tuple in hf_adapter, convert torch.Tensor to list in hf_tokenizer.decode for decoding compatibility
This commit is contained in:
@@ -99,7 +99,11 @@ class HFGPTAdapter(PreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# Основной forward pass
|
# Основной forward pass
|
||||||
logits = self.llm_model(input_ids)
|
outputs = self.llm_model(input_ids)
|
||||||
|
if isinstance(outputs, tuple):
|
||||||
|
logits = outputs[0]
|
||||||
|
else:
|
||||||
|
logits = outputs
|
||||||
|
|
||||||
loss = None
|
loss = None
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -56,10 +56,24 @@ class HFTokenizerAdapter:
|
|||||||
add_special_tokens = kwargs.get('add_special_tokens', True)
|
add_special_tokens = kwargs.get('add_special_tokens', True)
|
||||||
|
|
||||||
# Кодируем текст
|
# Кодируем текст
|
||||||
input_ids = self.llm_tokenizer.encode(
|
#input_ids = self.llm_tokenizer.encode(
|
||||||
text,
|
# text,
|
||||||
add_special_tokens=add_special_tokens
|
# add_special_tokens=add_special_tokens
|
||||||
)
|
#)
|
||||||
|
if isinstance(text, str):
|
||||||
|
input_ids = self.llm_tokenizer.encode(
|
||||||
|
text,
|
||||||
|
add_special_tokens=add_special_tokens
|
||||||
|
)
|
||||||
|
input_ids = [input_ids] # <-- оборачиваем в batch
|
||||||
|
else:
|
||||||
|
# Список строк, батч-режим!
|
||||||
|
input_ids = [
|
||||||
|
self.llm_tokenizer.encode(
|
||||||
|
t,
|
||||||
|
add_special_tokens=add_special_tokens
|
||||||
|
) for t in text
|
||||||
|
]
|
||||||
|
|
||||||
# Применяем truncation
|
# Применяем truncation
|
||||||
if truncation and max_length is not None and len(input_ids) > max_length:
|
if truncation and max_length is not None and len(input_ids) > max_length:
|
||||||
|
|||||||
Reference in New Issue
Block a user