Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2025-12-19 09:44:19 +01:00
always use sandwich norm for attention layer
@@ -614,7 +614,6 @@ class Attention(nn.Module):
         heads = 8,
         dropout = 0.,
         causal = False,
-        post_norm = False,
         rotary_emb = None
     ):
         super().__init__()
@@ -624,7 +623,6 @@ class Attention(nn.Module):
 
         self.causal = causal
         self.norm = LayerNorm(dim)
-        self.post_norm = LayerNorm(dim) # sandwich norm from Coqview paper + Normformer
         self.dropout = nn.Dropout(dropout)
 
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
@@ -635,7 +633,7 @@ class Attention(nn.Module):
 
         self.to_out = nn.Sequential(
             nn.Linear(inner_dim, dim, bias = False),
-            LayerNorm(dim) if post_norm else nn.Identity()
+            LayerNorm(dim)
         )
 
     def forward(self, x, mask = None, attn_bias = None):
@@ -692,8 +690,7 @@ class Attention(nn.Module):
         out = einsum('b h i j, b j d -> b h i d', attn, v)
 
         out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-        return self.post_norm(out)
+        return self.to_out(out)
 
 class CausalTransformer(nn.Module):
     def __init__(
@@ -719,7 +716,7 @@ class CausalTransformer(nn.Module):
         self.layers = nn.ModuleList([])
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, post_norm = normformer, rotary_emb = rotary_emb),
+                Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout, rotary_emb = rotary_emb),
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout, post_activation_norm = normformer)
             ]))
 
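For context, the change makes the sandwich arrangement unconditional: a LayerNorm before the attention computation (pre-norm) plus a second LayerNorm after the output projection, folded into to_out, instead of gating the post norm behind a post_norm flag and applying a separate self.post_norm at the end of forward. The sketch below is a minimal, self-contained illustration of that pattern, not the repository's module: the class name SandwichAttention is made up, nn.LayerNorm stands in for the repo's custom LayerNorm, and causal masking, null key/values, rotary embeddings, attention bias, and dropout are all omitted.

import torch
from torch import nn, einsum
from einops import rearrange

class SandwichAttention(nn.Module):
    # illustrative stand-in, not the DALLE2-pytorch Attention module
    def __init__(self, dim, dim_head = 64, heads = 8):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)           # pre-norm on the block input
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)                   # post-projection norm, now applied unconditionally
        )

    def forward(self, x):
        x = self.norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)                 # the second norm lives inside to_out, so forward just returns it

# quick shape check
x = torch.randn(1, 16, 512)
assert SandwichAttention(512)(x).shape == (1, 16, 512)

Because the post-projection norm now sits inside to_out, the forward pass collapses to a single return statement, which is exactly the change in the last Attention hunk above; CausalTransformer correspondingly drops the post_norm = normformer argument when constructing its attention layers.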