From df4dac4f5a62f40e0ec9e92a372a538f08b102cb Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Tue, 12 Apr 2022 10:23:07 -0700
Subject: [PATCH] bring in attention - it is all we need

---
 dalle2_pytorch/dalle2_pytorch.py | 80 ++++++++++++++++++++++++++++++--
 1 file changed, 76 insertions(+), 4 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 7dcf012..c47915a 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -1,7 +1,9 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
+
 from einops import rearrange
+from einops_exts import rearrange_many
 
 # use x-clip
 
@@ -42,23 +44,93 @@ def freeze_model_and_make_eval_(model):
 
 # diffusion prior
 
-class Transformer(nn.Module):
+def FeedForward(dim, mult = 4):
+    inner_dim = int(mult * dim)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias = False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias = False)
+    )
+
+class Attention(nn.Module):
     def __init__(
         self,
         *,
         dim,
         dim_head = 64,
-        heads = 8,
-
+        heads = 8
     ):
         super().__init__()
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+    def forward(self, x, mask = None):
+        n, device = x.shape[1], x.device
+
+        x = self.norm(x)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+
+        q, k, v = rearrange_many(qkv, 'b n (h d) -> b h n d', h = self.heads)
+
+        q = q * self.scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+        max_neg_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b 1 1 j')
+            sim = sim.masked_fill(~mask, max_neg_value)
+
+        causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
+        sim = sim.masked_fill(causal_mask, max_neg_value)
+
+        sim = sim - sim.amax(dim = -1, keepdim = True)
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        norm_out = False
+    ):
+        super().__init__()
+        # todo - bring in rotary embeddings or alibi
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Attention(dim = dim, dim_head = dim_head, heads = heads),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+        self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
 
     def forward(
         self,
        x,
         mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
     ):
-        return x
+        for attn, ff in self.layers:
+            x = attn(x, mask = mask) + x
+            x = ff(x) + x
+
+        return self.norm(x)
 
 class DiffusionPrior(nn.Module):
     def __init__(
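
Below is a quick usage sketch, not part of the patch itself, showing one way the new Attention / Transformer modules could be exercised once the patch is applied. The import path, hyperparameters, and tensor shapes are illustrative assumptions, and it presumes the `exists` helper referenced in Attention.forward is already defined earlier in dalle2_pytorch.py.

import torch
from dalle2_pytorch.dalle2_pytorch import Transformer  # assumed import path, per the file touched above

# hypothetical hyperparameters, chosen only for illustration
model = Transformer(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    norm_out = True   # exercise the optional final LayerNorm
)

text_encodings = torch.randn(2, 77, 512)        # (batch, sequence, dim)
mask = torch.ones(2, 77, dtype = torch.bool)    # True where a token should be attended to

out = model(text_encodings, mask = mask)        # (2, 77, 512)

Each block applies pre-norm attention and a pre-norm feedforward with residual connections; the padding mask passed here is combined with the always-on causal mask inside Attention.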