Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
always use sandwich norm for attention layer
```diff
@@ -614,7 +614,6 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False,
         rotary_emb = None
     ):
         super().__init__()
@@ -624,7 +623,6 @@ class Attention(nn.Module):
 
         self.causal = causal
         self.norm = LayerNorm(dim)
-        self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
         self.dropout = nn.Dropout(dropout)
 
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -635,7 +633,7 @@ class Attention(nn.Module):
 
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
-            LayerNorm(dim) if post_norm else nn.Identity()
+            LayerNorm(dim)
         )
 
     def forward(self, x, mask = None, attn_bias = None):
@@ -692,8 +690,7 @@ class Attention(nn.Module):
         out = einsum('b h i j, b j d -> b h i d', attn, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-        return self.post_norm(out)
+        return self.to_out(out)
 
 class CausalTransformer(nn.Module):
     def __init__(
@@ -719,7 +716,7 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
```
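The change removes the `post_norm` constructor flag and the separate `self.post_norm` layer applied in `forward`, so the attention block now always ends with a single LayerNorm after the output projection (the sandwich-norm scheme from CogView / NormFormer), in addition to the usual pre-norm on the input. Below is a minimal, self-contained sketch of that pattern; it is an illustration only, not the repository's `Attention` class, which additionally handles causal masking, null key/values, rotary embeddings and an attention bias, and uses its own `LayerNorm` implementation rather than `nn.LayerNorm`.

```python
# Minimal sketch of the sandwich-norm pattern this commit makes unconditional.
# Simplified for illustration; names and structure here are not the repo's exact code.

import torch
from torch import nn, einsum
from einops import rearrange

class SandwichNormAttention(nn.Module):
    def __init__(self, dim, dim_head = 64, heads = 8, dropout = 0.):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = dim_head * heads

        self.norm = nn.LayerNorm(dim)      # pre-norm on the input
        self.dropout = nn.Dropout(dropout)

        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        # post-norm after the output projection - the "sandwich" half that the
        # commit applies unconditionally instead of gating on a post_norm flag
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x):
        x = self.norm(x)

        q = self.to_q(x)
        k, v = self.to_kv(x).chunk(2, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)   # projection + LayerNorm, always applied
```

As the last hunk shows, `CausalTransformer` no longer threads its `normformer` flag into `Attention`; the flag still controls only the post-activation norm in `FeedForward`.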