diff --git a/llm/src/llm/training/trainer.py b/llm/src/llm/training/trainer.py index 5cfc7ac..966cdc0 100644 --- a/llm/src/llm/training/trainer.py +++ b/llm/src/llm/training/trainer.py @@ -53,8 +53,12 @@ class Trainer: input_ids = batch["input_ids"].to(self.device) labels = batch["labels"].to(self.device) - # Модель возвращает только логиты - logits = self.model(input_ids) + # Универсально обрабатываем выход (tuple/logits) + outputs = self.model(input_ids) + if isinstance(outputs, tuple): + logits = outputs[0] + else: + logits = outputs # Trainer вычисляет loss loss = self.compute_lm_loss(logits, labels) @@ -82,7 +86,11 @@ class Trainer: input_ids = batch["input_ids"].to(self.device) labels = batch["labels"].to(self.device) - logits = self.model(input_ids) + outputs = self.model(input_ids) + if isinstance(outputs, tuple): + logits = outputs[0] + else: + logits = outputs loss = self.compute_lm_loss(logits, labels) total_loss += loss.item()