mirror of
https://github.com/pese-git/llm-arch-research.git
synced 2026-01-23 21:10:54 +00:00
fix: universal logits extraction for tuple/model output in Trainer (GPT/GPT2 compatibility)
This commit is contained in:
@@ -53,8 +53,12 @@ class Trainer:
|
|||||||
input_ids = batch["input_ids"].to(self.device)
|
input_ids = batch["input_ids"].to(self.device)
|
||||||
labels = batch["labels"].to(self.device)
|
labels = batch["labels"].to(self.device)
|
||||||
|
|
||||||
# Модель возвращает только логиты
|
# Универсально обрабатываем выход (tuple/logits)
|
||||||
logits = self.model(input_ids)
|
outputs = self.model(input_ids)
|
||||||
|
if isinstance(outputs, tuple):
|
||||||
|
logits = outputs[0]
|
||||||
|
else:
|
||||||
|
logits = outputs
|
||||||
|
|
||||||
# Trainer вычисляет loss
|
# Trainer вычисляет loss
|
||||||
loss = self.compute_lm_loss(logits, labels)
|
loss = self.compute_lm_loss(logits, labels)
|
||||||
@@ -82,7 +86,11 @@ class Trainer:
|
|||||||
input_ids = batch["input_ids"].to(self.device)
|
input_ids = batch["input_ids"].to(self.device)
|
||||||
labels = batch["labels"].to(self.device)
|
labels = batch["labels"].to(self.device)
|
||||||
|
|
||||||
logits = self.model(input_ids)
|
outputs = self.model(input_ids)
|
||||||
|
if isinstance(outputs, tuple):
|
||||||
|
logits = outputs[0]
|
||||||
|
else:
|
||||||
|
logits = outputs
|
||||||
loss = self.compute_lm_loss(logits, labels)
|
loss = self.compute_lm_loss(logits, labels)
|
||||||
total_loss += loss.item()
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user