mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
start using swish glu everywhere, given success of PaLM
This commit is contained in:
@@ -164,12 +164,21 @@ class MLP(nn.Module):
|
||||
|
||||
# feedforward
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
||||
class SwiGLU(nn.Module):
|
||||
""" used successfully in https://arxiv.org/abs/2204.0231 """
|
||||
def forward(self, x):
|
||||
x, gate = x.chunk(2, dim = -1)
|
||||
return x * F.silu(gate)
|
||||
|
||||
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||
""" post-activation norm https://arxiv.org/abs/2110.09456 """
|
||||
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
RMSNorm(dim),
|
||||
nn.Linear(dim, inner_dim, bias = False),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim, inner_dim * 2, bias = False),
|
||||
SwiGLU(),
|
||||
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim, bias = False)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user