diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index fb6770d..681df5d 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 from torch import nn, einsum
 
 from einops import rearrange
-from einops_exts import rearrange_many
+from einops_exts import rearrange_many, repeat_many
 
 # use x-clip
 
@@ -82,23 +82,31 @@ class Attention(nn.Module):
         self.norm = RMSNorm(dim)
         self.dropout = nn.Dropout(dropout)
 
+        self.null_kv = nn.Parameter(torch.randn(heads, 2, dim_head))
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
         self.to_out = nn.Linear(inner_dim, dim, bias = False)
 
     def forward(self, x, mask = None):
-        n, device = x.shape[1], x.device
+        b, n, device = *x.shape[:2], x.device
 
         x = self.norm(x)
         qkv = self.to_qkv(x).chunk(3, dim = -1)
         q, k, v = rearrange_many(qkv, 'b n (h d) -> b h n d', h = self.heads)
 
+        # add null key / value for classifier free guidance in prior net
+
+        nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'h d -> b h 1 d', b = b)
+        k = torch.cat((nk, k), dim = -2)
+        v = torch.cat((nv, v), dim = -2)
+
         q = q * self.scale
 
         sim = einsum('b h i d, b h j d -> b h i j', q, k)
         max_neg_value = -torch.finfo(sim.dtype).max
 
         if exists(mask):
+            mask = F.pad(mask, (1, 0), value = True)
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
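The added null key / value pair matters because classifier-free guidance can drop every conditioning token; with all keys masked, a softmax row would be entirely `-inf` and produce NaNs. Below is a minimal sketch of that mechanism outside the diff, using plain `einops.repeat` in place of `repeat_many` and hypothetical sizes (`b = 2`, `heads = 8`, `dim_head = 64`); it illustrates the idea and is not the library's code.

```python
import torch
import torch.nn.functional as F
from torch import einsum
from einops import rearrange, repeat

# hypothetical sizes, not taken from the diff
b, h, n, d = 2, 8, 5, 64

q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# learned null key / value, one pair per head (mirrors the diff's self.null_kv parameter)
null_kv = torch.randn(h, 2, d)

# prepend the null key / value so every query always has one unmasked slot to attend to
nk, nv = (repeat(t, 'h d -> b h 1 d', b = b) for t in null_kv.unbind(dim = -2))
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)

sim = einsum('b h i d, b h j d -> b h i j', q * d ** -0.5, k)
max_neg_value = -torch.finfo(sim.dtype).max

# simulate classifier-free guidance dropping every conditioning token
mask = torch.zeros(b, n, dtype = torch.bool)

# pad the mask with True for the null slot, as the diff does before applying it
mask = F.pad(mask, (1, 0), value = True)
sim = sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), max_neg_value)

attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)

assert not torch.isnan(out).any()  # every row still attends to the null slot, so softmax stays finite
```

When everything else is masked, each query attends only to the learned null slot, which is exactly the degenerate case the `F.pad(mask, (1, 0), value = True)` line in the diff guards against.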