further prepare attention for classifier free guidance

This commit is contained in:
Phil Wang
2022-04-12 13:01:18 -07:00
parent 7647be2569
commit 74aec9d8ca

View File

@@ -3,7 +3,7 @@ import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
from einops_exts import rearrange_many
from einops_exts import rearrange_many, repeat_many
# use x-clip
@@ -82,23 +82,31 @@ class Attention(nn.Module):
self.norm = RMSNorm(dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(heads, 2, dim_head))
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, mask = None):
n, device = x.shape[1], x.device
b, n, device = x.shape[:2], x.device
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = rearrange_many(qkv, 'b n (h d) -> b h n d')
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'h d -> b h 1 d', b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j')
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)