use post-attn-branch layernorm in an attempt to stabilize cross attention conditioning in the decoder

Phil Wang
2022-05-14 11:58:09 -07:00
parent 5d27029e98
commit 9faab59b23
2 changed files with 6 additions and 2 deletions


@@ -1181,7 +1181,11 @@ class CrossAttention(nn.Module):
         self.null_kv = nn.Parameter(torch.randn(2, dim_head))
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Linear(inner_dim, dim, bias = False),
+            LayerNorm(dim)
+        )

     def forward(self, x, context, mask = None):
         b, n, device = *x.shape[:2], x.device
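
For context, a minimal, self-contained sketch of where this post-attention-branch normalization sits relative to the residual stream. The class name CrossAttentionSketch, the use of torch.nn.LayerNorm (the diff's LayerNorm is a custom class defined elsewhere in the file), the inlined residual add, and the omission of the null key/values and masking are all simplifications for illustration, not the repository's exact implementation.

import torch
from torch import nn, einsum
from einops import rearrange

class CrossAttentionSketch(nn.Module):
    def __init__(self, dim, context_dim, dim_head = 64, heads = 8):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)  # pre-norm on the queries' input
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)

        # post-attention-branch layernorm: the branch output is normalized
        # right after the output projection, before it rejoins the residual
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim, bias = False),
            nn.LayerNorm(dim)
        )

    def forward(self, x, context):
        h = self.heads

        q = self.to_q(self.norm(x))
        k, v = self.to_kv(context).chunk(2, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        # the normalized branch output is what gets added back to the residual stream
        return x + self.to_out(out)

Usage sketch, with arbitrary shapes chosen only for illustration:

x = torch.randn(2, 32, 512)         # decoder tokens
context = torch.randn(2, 77, 768)   # conditioning tokens (e.g. text embeddings)
out = CrossAttentionSketch(dim = 512, context_dim = 768)(x, context)  # (2, 32, 512)

Normalizing the branch output bounds the scale of what the cross-attention path injects into the residual stream, which is the stabilization the commit message is aiming for.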