diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index e704e74..fb67a2d 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -164,12 +164,21 @@ class MLP(nn.Module):
 
 # feedforward
 
-def FeedForward(dim, mult = 4, dropout = 0.):
+class SwiGLU(nn.Module):
+    """ used successfully in https://arxiv.org/abs/2204.02311 """
+    def forward(self, x):
+        x, gate = x.chunk(2, dim = -1)
+        return x * F.silu(gate)
+
+def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
+    """ post-activation norm https://arxiv.org/abs/2110.09456 """
+
     inner_dim = int(mult * dim)
     return nn.Sequential(
         RMSNorm(dim),
-        nn.Linear(dim, inner_dim, bias = False),
-        nn.GELU(),
+        nn.Linear(dim, inner_dim * 2, bias = False),
+        SwiGLU(),
+        RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
         nn.Dropout(dropout),
         nn.Linear(inner_dim, dim, bias = False)
     )
diff --git a/setup.py b/setup.py
index a13aeb1..99c68f8 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.9',
+  version = '0.0.10',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',