From d9af3dba35e228a8ce7e1c1a34f55a80d2b40e77 Mon Sep 17 00:00:00 2001 From: Sergey Penkovsky Date: Mon, 21 Jul 2025 10:07:52 +0300 Subject: [PATCH] Optimize feed forward: improve dtype handling and layer processing --- simple_llm/transformer/feed_forward.py | 30 ++++++++++++++++++-------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/simple_llm/transformer/feed_forward.py b/simple_llm/transformer/feed_forward.py index 81e7558..1d17188 100644 --- a/simple_llm/transformer/feed_forward.py +++ b/simple_llm/transformer/feed_forward.py @@ -46,12 +46,14 @@ class FeedForward(nn.Module): dropout: Вероятность dropout для регуляризации (по умолчанию: 0.1) """ super().__init__() - self.net = nn.Sequential( - nn.Linear(emb_size, 4 * emb_size), - nn.ReLU(), - nn.Linear(4 * emb_size, emb_size), - nn.Dropout(dropout) - ) + # Первый линейный слой (расширение размерности) + self._layer1 = nn.Linear(emb_size, emb_size * 4) + # ReLU активация + self._relu = nn.ReLU() + # Второй линейный слой (сжатие обратно) + self._layer2 = nn.Linear(emb_size * 4, emb_size) + # Dropout + self._dropout = nn.Dropout(dropout) def forward(self, x: torch.Tensor): """ @@ -63,6 +65,16 @@ class FeedForward(nn.Module): Returns: Тензор той же размерности, что и входной """ - # Приводим все параметры сети к типу входного тензора - self.net = self.net.to(x.dtype) - return self.net(x) \ No newline at end of file + # Сохраняем dtype входных данных + input_dtype = x.dtype + + # Приводим веса к нужному типу если необходимо + if input_dtype != self._layer1.weight.dtype: + self._layer1 = self._layer1.to(dtype=input_dtype) + self._layer2 = self._layer2.to(dtype=input_dtype) + + # Пропустим тензор x по очереди через все созданные слои + x = self._layer1(x) + x = self._relu(x) + x = self._layer2(x) + return self._dropout(x) \ No newline at end of file