mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-21 22:04:40 +01:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
95b018374a | ||
|
|
8b5c2385b0 |
@@ -164,12 +164,21 @@ class MLP(nn.Module):
|
|||||||
|
|
||||||
# feedforward
|
# feedforward
|
||||||
|
|
||||||
def FeedForward(dim, mult = 4, dropout = 0.):
|
class SwiGLU(nn.Module):
|
||||||
|
""" used successfully in https://arxiv.org/abs/2204.0231 """
|
||||||
|
def forward(self, x):
|
||||||
|
x, gate = x.chunk(2, dim = -1)
|
||||||
|
return x * F.silu(gate)
|
||||||
|
|
||||||
|
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||||
|
""" post-activation norm https://arxiv.org/abs/2110.09456 """
|
||||||
|
|
||||||
inner_dim = int(mult * dim)
|
inner_dim = int(mult * dim)
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
RMSNorm(dim),
|
RMSNorm(dim),
|
||||||
nn.Linear(dim, inner_dim, bias = False),
|
nn.Linear(dim, inner_dim * 2, bias = False),
|
||||||
nn.GELU(),
|
SwiGLU(),
|
||||||
|
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||||
nn.Dropout(dropout),
|
nn.Dropout(dropout),
|
||||||
nn.Linear(inner_dim, dim, bias = False)
|
nn.Linear(inner_dim, dim, bias = False)
|
||||||
)
|
)
|
||||||
@@ -320,8 +329,8 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
# but let's just do it right
|
# but let's just do it right
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
all_masked_out = mask.any(dim = -1)
|
not_all_masked_out = mask.any(dim = -1)
|
||||||
mask = torch.cat((mask, rearrange(all_masked_out, 'b -> b 1')), dim = 1)
|
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
|
||||||
|
|
||||||
if exists(mask):
|
if exists(mask):
|
||||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||||
|
|||||||
Reference in New Issue
Block a user