optional projection out for prior network causal transformer

commit 46cef31c86
parent 59b1a77d4d
Author: Phil Wang
Date: 2022-04-22 11:16:30 -07:00

2 changed files with 6 additions and 3 deletions

@@ -350,7 +350,8 @@ class CausalTransformer(nn.Module):
         ff_mult = 4,
         norm_out = False,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        final_proj = True
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
@@ -363,6 +364,7 @@ class CausalTransformer(nn.Module):
             ]))
 
         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
 
     def forward(
         self,
@@ -377,7 +379,8 @@ class CausalTransformer(nn.Module):
             x = attn(x, mask = mask, attn_bias = attn_bias) + x
             x = ff(x) + x
 
-        return self.norm(x)
+        out = self.norm(x)
+        return self.project_out(out)
 
 class DiffusionPriorNetwork(nn.Module):
     def __init__(
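A minimal usage sketch of the new flag (the import path, dim/depth values, and tensor shapes below are illustrative assumptions, not part of this commit): with final_proj = True, the default, the normed transformer output passes through a bias-free nn.Linear(dim, dim); with final_proj = False, project_out falls back to nn.Identity() and the hidden states come back unprojected.

import torch
from dalle2_pytorch.dalle2_pytorch import CausalTransformer  # assumed import path

# default after this commit: final bias-free linear projection on the normed output
transformer = CausalTransformer(dim = 512, depth = 6, norm_out = True)

# opting out: self.project_out becomes nn.Identity(), so hidden states are returned directly
transformer_no_proj = CausalTransformer(dim = 512, depth = 6, norm_out = True, final_proj = False)

tokens = torch.randn(2, 257, 512)        # (batch, sequence, dim) - illustrative shape
out = transformer(tokens)                # (2, 257, 512), projected
out_raw = transformer_no_proj(tokens)    # (2, 257, 512), unprojected

Either way the output keeps shape (batch, sequence, dim), since the projection maps dim to dim; the flag only toggles whether that extra learned linear layer sits after the final norm.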

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.34',
+  version = '0.0.35',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',