diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index da4fb6c..9f5499f 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -614,7 +614,6 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False,
         rotary_emb = None
     ):
         super().__init__()
@@ -624,7 +623,6 @@ class Attention(nn.Module):
         self.causal = causal
 
         self.norm = LayerNorm(dim)
-        self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
         self.dropout = nn.Dropout(dropout)
 
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -635,7 +633,7 @@ class Attention(nn.Module):
 
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
-            LayerNorm(dim) if post_norm else nn.Identity()
+            LayerNorm(dim)
         )
 
     def forward(self, x, mask = None, attn_bias = None):
@@ -692,8 +690,7 @@ class Attention(nn.Module):
         out = einsum('b h i j, b j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
 
-        out = self.to_out(out)
-        return self.post_norm(out)
+        return self.to_out(out)
 
 class CausalTransformer(nn.Module):
     def __init__(
@@ -719,7 +716,7 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
diff --git a/setup.py b/setup.py
index c287c38..cfe8f7a 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.18',
+  version = '0.2.19',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
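
Note (not part of the diff): a minimal sketch of what the Attention output path looks like after this change. The separate `self.post_norm` module and the `post_norm` flag are gone; the final LayerNorm now always sits inside `to_out`. The `LayerNorm` class below is a simplified stand-in for the repo's own helper, and the dimensions are illustrative only.

import torch
from torch import nn

class LayerNorm(nn.Module):
    # simplified stand-in for dalle2_pytorch's LayerNorm helper,
    # just to make the sketch self-contained
    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(x)

# illustrative dimensions matching the repo's defaults (dim_head = 64, heads = 8)
dim, heads, dim_head = 512, 8, 64
inner_dim = heads * dim_head

# after 0.2.19 the output projection always ends in a LayerNorm,
# instead of `LayerNorm(dim) if post_norm else nn.Identity()`
to_out = nn.Sequential(
    nn.Linear(inner_dim, dim, bias = False),
    LayerNorm(dim)
)

attn_out = torch.randn(2, 10, inner_dim)   # (batch, seq, heads * dim_head)
print(to_out(attn_out).shape)              # torch.Size([2, 10, 512])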