mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-23 14:34:33 +01:00
go for the multi-headed queries, one-headed key/values, proven out in AlphaCode as well as PaLM by now
This commit is contained in:
@@ -125,8 +125,9 @@ class Attention(nn.Module):
|
||||
self.norm = RMSNorm(dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(heads, 2, dim_head))
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
self.to_qkv = nn.Linear(dim, inner_dim, bias = False)
|
||||
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim, bias = False)
|
||||
|
||||
def forward(self, x, mask = None):
|
||||
@@ -135,17 +136,17 @@ class Attention(nn.Module):
|
||||
x = self.norm(x)
|
||||
qkv = self.to_qkv(x).chunk(3, dim = -1)
|
||||
|
||||
q, k, v = rearrange_many(qkv, 'b n (h d) -> b h n d')
|
||||
q = rearrange(q, 'b n (h d) -> b h n d')
|
||||
|
||||
# add null key / value for classifier free guidance in prior net
|
||||
|
||||
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'h d -> b h 1 d', b = b)
|
||||
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
|
||||
k = torch.cat((nk, k), dim = -2)
|
||||
v = torch.cat((nv, v), dim = -2)
|
||||
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum('b h i d, b h j d -> b h i j')
|
||||
sim = einsum('b h i d, b j d -> b h i j')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
|
||||
if exists(mask):
|
||||
@@ -159,7 +160,7 @@ class Attention(nn.Module):
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True)
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
Reference in New Issue
Block a user