fix: universal logits extraction for tuple or tensor model outputs in Trainer (GPT/GPT-2 compatibility)

This commit is contained in:
Sergey Penkovsky
2025-10-05 15:52:21 +03:00
parent aa408e941a
commit f866ed7ac7

View File

@@ -53,8 +53,12 @@ class Trainer:
         input_ids = batch["input_ids"].to(self.device)
         labels = batch["labels"].to(self.device)
-        # Модель возвращает только логиты
-        logits = self.model(input_ids)
+        # Универсально обрабатываем выход (tuple/logits)
+        outputs = self.model(input_ids)
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        else:
+            logits = outputs
         # Trainer вычисляет loss
         loss = self.compute_lm_loss(logits, labels)
@@ -82,7 +86,11 @@ class Trainer:
         input_ids = batch["input_ids"].to(self.device)
         labels = batch["labels"].to(self.device)
-        logits = self.model(input_ids)
+        outputs = self.model(input_ids)
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        else:
+            logits = outputs
         loss = self.compute_lm_loss(logits, labels)
         total_loss += loss.item()