mirror of https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-20 18:24:19 +01:00
bring in attention - it is all we need
@@ -1,7 +1,9 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
 
 from einops import rearrange
+from einops_exts import rearrange_many
 
 # use x-clip
+
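The functional change in this hunk is the new einops-exts dependency: rearrange_many applies one einops pattern to each tensor in a collection, which is how the fused qkv projection in the attention layer below is split into per-head tensors. A minimal sketch of the behaviour, with illustrative shapes:

import torch
from einops_exts import rearrange_many

# three tensors with identical layout, e.g. the chunks of a fused qkv projection
q, k, v = (torch.randn(2, 10, 8 * 64) for _ in range(3))

# one einops pattern applied to each tensor in turn, splitting out the heads
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = 8)

print(q.shape)  # torch.Size([2, 8, 10, 64])

The same call shows up in Attention.forward below, with the chunks of the qkv projection as input.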
@@ -42,23 +44,92 @@ def freeze_model_and_make_eval_(model):
 
 # diffusion prior
 
-class Transformer(nn.Module):
+def FeedForward(dim, mult = 4):
+    inner_dim = int(mult * dim)
+    return nn.Sequential(
+        nn.LayerNorm(dim),
+        nn.Linear(dim, inner_dim, bias = False),
+        nn.GELU(),
+        nn.Linear(inner_dim, dim, bias = False)
+    )
 
+class Attention(nn.Module):
     def __init__(
         self,
         *,
         dim,
         dim_head = 64,
-        heads = 8,
+        heads = 8
     ):
         super().__init__()
+        self.heads = heads
+        self.scale = dim_head ** -0.5
+        inner_dim = dim_head * heads
+
+        self.norm = nn.LayerNorm(dim)
+        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)
+
+    def forward(self, x, mask = None):
+        n, device = x.shape[1], x.device
+
+        x = self.norm(x)
+        qkv = self.to_qkv(x).chunk(3, dim = -1)
+
+        q, k, v = rearrange_many(qkv, 'b n (h d) -> b h n d', h = self.heads)
+
+        q = q * self.scale
+
+        sim = einsum('b h i d, b h j d -> b h i j', q, k)
+        max_neg_value = -torch.finfo(sim.dtype).max
+
+        if exists(mask):
+            mask = rearrange(mask, 'b j -> b 1 1 j') # key padding mask, broadcast over heads and query positions
+            sim = sim.masked_fill(~mask, max_neg_value)
+
+        causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
+        sim = sim.masked_fill(causal_mask, max_neg_value)
+
+        sim = sim - sim.amax(dim = -1, keepdim = True) # for numerical stability of the softmax
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('b h i j, b h j d -> b h i d', attn, v)
+
+        out = rearrange(out, 'b h n d -> b n (h d)')
+        return self.to_out(out)
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        depth,
+        dim_head = 64,
+        heads = 8,
+        ff_mult = 4,
+        norm_out = False
+    ):
+        super().__init__()
+        # todo - bring in rotary embeddings or alibi
+
+        self.layers = nn.ModuleList([])
+        for _ in range(depth):
+            self.layers.append(nn.ModuleList([
+                Attention(dim = dim, dim_head = dim_head, heads = heads),
+                FeedForward(dim = dim, mult = ff_mult)
+            ]))
+
+        self.norm = nn.LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
 
     def forward(
         self,
         x,
         mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
     ):
-        return x
+        for attn, ff in self.layers:
+            x = attn(x) + x
+            x = ff(x) + x
+
+        return self.norm(x)
 
 class DiffusionPrior(nn.Module):
     def __init__(
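As a quick sanity check of the modules introduced above, the sketch below instantiates them directly. It assumes the classes from this diff are in scope along with the file's existing exists helper, and all sizes are illustrative. Note that in this commit Transformer.forward accepts a mask but does not yet pass it on to Attention, so the mask only takes effect when Attention is called directly.

import torch

dim = 512

attn = Attention(dim = dim, dim_head = 64, heads = 8)
model = Transformer(dim = dim, depth = 6, norm_out = True)

tokens = torch.randn(2, 77, dim)      # (batch, sequence, dim), e.g. text encodings plus learned queries
mask = torch.ones(2, 77).bool()       # True for real tokens, False for padding

attended = attn(tokens, mask = mask)  # key padding mask applied inside Attention
out = model(tokens, mask = mask)      # mask accepted but not yet used by Transformer in this commit

print(attended.shape, out.shape)      # torch.Size([2, 77, 512]) torch.Size([2, 77, 512])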