Compare commits

...

2 Commits

Author SHA1 Message Date
Phil Wang
95b018374a start using swish glu everywhere, given success of PaLM 2022-04-14 09:34:32 -07:00
Phil Wang
8b5c2385b0 better naming 2022-04-14 09:24:31 -07:00
2 changed files with 15 additions and 6 deletions

View File

@@ -164,12 +164,21 @@ class MLP(nn.Module):
# feedforward # feedforward
def FeedForward(dim, mult = 4, dropout = 0.): class SwiGLU(nn.Module):
""" used successfully in https://arxiv.org/abs/2204.0231 """
def forward(self, x):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
inner_dim = int(mult * dim) inner_dim = int(mult * dim)
return nn.Sequential( return nn.Sequential(
RMSNorm(dim), RMSNorm(dim),
nn.Linear(dim, inner_dim, bias = False), nn.Linear(dim, inner_dim * 2, bias = False),
nn.GELU(), SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout), nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False) nn.Linear(inner_dim, dim, bias = False)
) )
@@ -320,8 +329,8 @@ class DiffusionPriorNetwork(nn.Module):
# but let's just do it right # but let's just do it right
if exists(mask): if exists(mask):
all_masked_out = mask.any(dim = -1) not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1) mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask): if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.9', version = '0.0.10',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',