optional projection out for prior network causal transformer

Phil Wang
2022-04-22 11:16:30 -07:00
parent 59b1a77d4d
commit 46cef31c86
2 changed files with 6 additions and 3 deletions

@@ -350,7 +350,8 @@ class CausalTransformer(nn.Module):
         ff_mult = 4,
         norm_out = False,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        final_proj = True
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
@@ -363,6 +364,7 @@ class CausalTransformer(nn.Module):
             ]))

         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()

     def forward(
         self,
@@ -377,7 +379,8 @@ class CausalTransformer(nn.Module):
             x = attn(x, mask = mask, attn_bias = attn_bias) + x
             x = ff(x) + x

-        return self.norm(x)
+        out = self.norm(x)
+        return self.project_out(out)

 class DiffusionPriorNetwork(nn.Module):
     def __init__(
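
For reference, a minimal usage sketch of the new flag. It assumes CausalTransformer is importable from dalle2_pytorch.dalle2_pytorch and that dim, depth, and heads are accepted constructor keywords (inferred from RelPosBias(heads = heads) and LayerNorm(dim) in the surrounding code); none of that is confirmed by this diff.

import torch
from dalle2_pytorch.dalle2_pytorch import CausalTransformer  # import path assumed

# final_proj = True (the default) keeps the trailing nn.Linear(dim, dim, bias = False);
# final_proj = False swaps it for nn.Identity(), so the transformer returns the
# (optionally layer-normed) hidden states directly
transformer = CausalTransformer(
    dim = 512,         # assumed keyword, per LayerNorm(dim) in the diff context
    depth = 6,         # assumed keyword
    heads = 8,         # assumed keyword, per RelPosBias(heads = heads)
    norm_out = True,
    final_proj = False
)

tokens = torch.randn(1, 257, 512)  # (batch, sequence, dim), hypothetical shapes
out = transformer(tokens)          # output shape matches input: (1, 257, 512)

Since both branches map dim to dim, toggling the flag never changes output shape, only whether a learned linear layer sits between the final norm and the caller.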

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.34',
+  version = '0.0.35',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',