mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-23 21:14:17 +00:00
Optimize feed forward: improve dtype handling and layer processing
This commit is contained in:
@@ -46,12 +46,14 @@ class FeedForward(nn.Module):
|
||||
dropout: Вероятность dropout для регуляризации (по умолчанию: 0.1)
|
||||
"""
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(emb_size, 4 * emb_size),
|
||||
nn.ReLU(),
|
||||
nn.Linear(4 * emb_size, emb_size),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
# Первый линейный слой (расширение размерности)
|
||||
self._layer1 = nn.Linear(emb_size, emb_size * 4)
|
||||
# ReLU активация
|
||||
self._relu = nn.ReLU()
|
||||
# Второй линейный слой (сжатие обратно)
|
||||
self._layer2 = nn.Linear(emb_size * 4, emb_size)
|
||||
# Dropout
|
||||
self._dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""
|
||||
@@ -63,6 +65,16 @@ class FeedForward(nn.Module):
|
||||
Returns:
|
||||
Тензор той же размерности, что и входной
|
||||
"""
|
||||
# Приводим все параметры сети к типу входного тензора
|
||||
self.net = self.net.to(x.dtype)
|
||||
return self.net(x)
|
||||
# Сохраняем dtype входных данных
|
||||
input_dtype = x.dtype
|
||||
|
||||
# Приводим веса к нужному типу если необходимо
|
||||
if input_dtype != self._layer1.weight.dtype:
|
||||
self._layer1 = self._layer1.to(dtype=input_dtype)
|
||||
self._layer2 = self._layer2.to(dtype=input_dtype)
|
||||
|
||||
# Пропустим тензор x по очереди через все созданные слои
|
||||
x = self._layer1(x)
|
||||
x = self._relu(x)
|
||||
x = self._layer2(x)
|
||||
return self._dropout(x)
|
||||
Reference in New Issue
Block a user