From 46cef31c86383f2339750883b3be655c56906576 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 22 Apr 2022 11:16:30 -0700
Subject: [PATCH] optional projection out for prior network causal transformer

---
 dalle2_pytorch/dalle2_pytorch.py | 7 +++++--
 setup.py                         | 2 +-
 2 files changed, 6 insertions(+), 3 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 16f8095..41cbab4 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -350,7 +350,8 @@ class CausalTransformer(nn.Module):
         ff_mult = 4,
         norm_out = False,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        final_proj = True
     ):
         super().__init__()
         self.rel_pos_bias = RelPosBias(heads = heads)
@@ -363,6 +364,7 @@ class CausalTransformer(nn.Module):
             ]))
 
         self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
+        self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
 
     def forward(
         self,
@@ -377,7 +379,8 @@ class CausalTransformer(nn.Module):
             x = attn(x, mask = mask, attn_bias = attn_bias) + x
             x = ff(x) + x
 
-        return self.norm(x)
+        out = self.norm(x)
+        return self.project_out(out)
 
 class DiffusionPriorNetwork(nn.Module):
     def __init__(
diff --git a/setup.py b/setup.py
index be472ca..fef0064 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.34',
+  version = '0.0.35',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
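
For reference, a minimal sketch of the two output modes the new final_proj
flag selects between. The hyperparameters below are illustrative (not taken
from the repository), and the other constructor arguments are assumed to
keep their defaults:

    import torch
    from dalle2_pytorch.dalle2_pytorch import CausalTransformer

    x = torch.randn(1, 77, 512)  # (batch, sequence length, dim)

    # final_proj = True (the default): the transformer output, after the
    # optional LayerNorm, is passed through a learned bias-free
    # nn.Linear(dim, dim) before being returned
    with_proj = CausalTransformer(dim = 512, depth = 6, final_proj = True)

    # final_proj = False: project_out is nn.Identity(), so the (normed)
    # transformer output is returned directly
    without_proj = CausalTransformer(dim = 512, depth = 6, final_proj = False)

    # the projection maps dim -> dim, so the output shape is unchanged either way
    assert with_proj(x).shape == without_proj(x).shape == x.shape

Either mode preserves the (batch, sequence, dim) shape; the flag only
controls whether one extra learned linear map sits between the final norm
and the caller, covering both readings of the paper mentioned in the
in-line comment above.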