mirror of
https://github.com/pese-git/simple-llm.git
synced 2026-01-24 05:26:02 +00:00
Optimize feed forward: improve dtype handling and layer processing
This commit is contained in:
@@ -46,12 +46,14 @@ class FeedForward(nn.Module):
|
|||||||
dropout: Вероятность dropout для регуляризации (по умолчанию: 0.1)
|
dropout: Вероятность dropout для регуляризации (по умолчанию: 0.1)
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.net = nn.Sequential(
|
# Первый линейный слой (расширение размерности)
|
||||||
nn.Linear(emb_size, 4 * emb_size),
|
self._layer1 = nn.Linear(emb_size, emb_size * 4)
|
||||||
nn.ReLU(),
|
# ReLU активация
|
||||||
nn.Linear(4 * emb_size, emb_size),
|
self._relu = nn.ReLU()
|
||||||
nn.Dropout(dropout)
|
# Второй линейный слой (сжатие обратно)
|
||||||
)
|
self._layer2 = nn.Linear(emb_size * 4, emb_size)
|
||||||
|
# Dropout
|
||||||
|
self._dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
def forward(self, x: torch.Tensor):
|
||||||
"""
|
"""
|
||||||
@@ -63,6 +65,16 @@ class FeedForward(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
Тензор той же размерности, что и входной
|
Тензор той же размерности, что и входной
|
||||||
"""
|
"""
|
||||||
# Приводим все параметры сети к типу входного тензора
|
# Сохраняем dtype входных данных
|
||||||
self.net = self.net.to(x.dtype)
|
input_dtype = x.dtype
|
||||||
return self.net(x)
|
|
||||||
|
# Приводим веса к нужному типу если необходимо
|
||||||
|
if input_dtype != self._layer1.weight.dtype:
|
||||||
|
self._layer1 = self._layer1.to(dtype=input_dtype)
|
||||||
|
self._layer2 = self._layer2.to(dtype=input_dtype)
|
||||||
|
|
||||||
|
# Пропустим тензор x по очереди через все созданные слои
|
||||||
|
x = self._layer1(x)
|
||||||
|
x = self._relu(x)
|
||||||
|
x = self._layer2(x)
|
||||||
|
return self._dropout(x)
|
||||||
Reference in New Issue
Block a user