From f866ed7ac70c7d5b6d814fc06b7c0a9cc8813c17 Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Sun, 5 Oct 2025 15:52:21 +0300 Subject: [PATCH] fix: universal logits extraction for tuple/model output in Trainer (GPT/GPT2 compatibility) --- llm/src/llm/training/trainer.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/llm/src/llm/training/trainer.py b/llm/src/llm/training/trainer.py index 5cfc7ac..966cdc0 100644 --- a/llm/src/llm/training/trainer.py +++ b/llm/src/llm/training/trainer.py @@ -53,8 +53,12 @@ class Trainer: input_ids = batch["input_ids"].to(self.device) labels = batch["labels"].to(self.device) - # Модель возвращает только логиты - logits = self.model(input_ids) + # Универсально обрабатываем выход (tuple/logits) + outputs = self.model(input_ids) + if isinstance(outputs, tuple): + logits = outputs[0] + else: + logits = outputs # Trainer вычисляет loss loss = self.compute_lm_loss(logits, labels) @@ -82,7 +86,11 @@ class Trainer: input_ids = batch["input_ids"].to(self.device) labels = batch["labels"].to(self.device) - logits = self.model(input_ids) + outputs = self.model(input_ids) + if isinstance(outputs, tuple): + logits = outputs[0] + else: + logits = outputs loss = self.compute_lm_loss(logits, labels) total_loss += loss.item()