Mirror of https://github.com/pese-git/llm-arch-research.git, synced 2026-01-23 21:10:54 +00:00
fix: universal logits extraction for tuple/model output in Trainer (GPT/GPT2 compatibility)
@@ -53,8 +53,12 @@ class Trainer:
         input_ids = batch["input_ids"].to(self.device)
         labels = batch["labels"].to(self.device)
 
-        # The model returns only the logits
-        logits = self.model(input_ids)
+        # Handle the model output universally (tuple or raw logits)
+        outputs = self.model(input_ids)
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        else:
+            logits = outputs
 
         # The Trainer computes the loss
         loss = self.compute_lm_loss(logits, labels)
@@ -82,7 +86,11 @@ class Trainer:
         input_ids = batch["input_ids"].to(self.device)
         labels = batch["labels"].to(self.device)
 
-        logits = self.model(input_ids)
+        outputs = self.model(input_ids)
+        if isinstance(outputs, tuple):
+            logits = outputs[0]
+        else:
+            logits = outputs
         loss = self.compute_lm_loss(logits, labels)
         total_loss += loss.item()
 
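For context, here is a minimal, self-contained sketch of the behavior this patch relies on: a model whose forward returns a tuple (GPT-style, e.g. `(logits, cache)`) and one that returns the logits tensor directly (GPT2-style here) are both handled by the same extraction branch. The two dummy model classes and the `extract_logits` helper are hypothetical illustrations, not code from the repository; only the `isinstance(outputs, tuple)` handling mirrors the patched Trainer step.

import torch
import torch.nn as nn

# Hypothetical stand-in: a forward that returns (logits, ...) as a tuple.
class TupleOutputModel(nn.Module):
    def forward(self, input_ids):
        logits = torch.randn(input_ids.shape[0], input_ids.shape[1], 100)
        return (logits, None)  # e.g. (logits, cache)

# Hypothetical stand-in: a forward that returns the logits tensor directly.
class TensorOutputModel(nn.Module):
    def forward(self, input_ids):
        return torch.randn(input_ids.shape[0], input_ids.shape[1], 100)

def extract_logits(outputs):
    # Same universal handling as in the patched Trainer step.
    if isinstance(outputs, tuple):
        return outputs[0]
    return outputs

input_ids = torch.randint(0, 100, (2, 16))
for model in (TupleOutputModel(), TensorOutputModel()):
    logits = extract_logits(model(input_ids))
    assert logits.shape == (2, 16, 100)

Under this assumption, the Trainer no longer depends on which model architecture (GPT or GPT2) produced the output before calling `compute_lm_loss(logits, labels)`.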