Optimize feed forward: improve dtype handling and layer processing

This commit is contained in:
Sergey Penkovsky
2025-07-21 10:07:52 +03:00
parent 6832978dc1
commit d9af3dba35

View File

@@ -46,12 +46,14 @@ class FeedForward(nn.Module):
dropout: Вероятность dropout для регуляризации (по умолчанию: 0.1)
"""
super().__init__()
self.net = nn.Sequential(
nn.Linear(emb_size, 4 * emb_size),
nn.ReLU(),
nn.Linear(4 * emb_size, emb_size),
nn.Dropout(dropout)
)
# Первый линейный слой (расширение размерности)
self._layer1 = nn.Linear(emb_size, emb_size * 4)
# ReLU активация
self._relu = nn.ReLU()
# Второй линейный слой (сжатие обратно)
self._layer2 = nn.Linear(emb_size * 4, emb_size)
# Dropout
self._dropout = nn.Dropout(dropout)
def forward(self, x: torch.Tensor):
"""
@@ -63,6 +65,16 @@ class FeedForward(nn.Module):
Returns:
Тензор той же размерности, что и входной
"""
# Приводим все параметры сети к типу входного тензора
self.net = self.net.to(x.dtype)
return self.net(x)
# Сохраняем dtype входных данных
input_dtype = x.dtype
# Приводим веса к нужному типу если необходимо
if input_dtype != self._layer1.weight.dtype:
self._layer1 = self._layer1.to(dtype=input_dtype)
self._layer2 = self._layer2.to(dtype=input_dtype)
# Пропустим тензор x по очереди через все созданные слои
x = self._layer1(x)
x = self._relu(x)
x = self._layer2(x)
return self._dropout(x)