mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
start a file for all attention-related modules, use attention-based upsampling in the unets in dalle-2
This commit is contained in:
125
dalle2_pytorch/attention.py
Normal file
125
dalle2_pytorch/attention.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
import torch
|
||||||
|
from torch import nn, einsum
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from einops import rearrange, repeat
|
||||||
|
|
||||||
|
class LayerNormChan(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dim,
|
||||||
|
eps = 1e-5
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.eps = eps
|
||||||
|
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||||
|
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||||
|
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||||
|
|
||||||
|
# attention-based upsampling
|
||||||
|
# from https://arxiv.org/abs/2112.11435
|
||||||
|
|
||||||
|
class QueryAndAttend(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dim,
|
||||||
|
num_queries = 1,
|
||||||
|
dim_head = 32,
|
||||||
|
heads = 8,
|
||||||
|
window_size = 3
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.scale = dim_head ** -0.5
|
||||||
|
inner_dim = dim_head * heads
|
||||||
|
self.heads = heads
|
||||||
|
self.dim_head = dim_head
|
||||||
|
self.window_size = window_size
|
||||||
|
self.num_queries = num_queries
|
||||||
|
|
||||||
|
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
|
||||||
|
|
||||||
|
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
|
||||||
|
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
|
||||||
|
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""
|
||||||
|
einstein notation
|
||||||
|
b - batch
|
||||||
|
h - heads
|
||||||
|
l - num queries
|
||||||
|
d - head dimension
|
||||||
|
x - height
|
||||||
|
y - width
|
||||||
|
j - source sequence for attending to (kernel size squared in this case)
|
||||||
|
"""
|
||||||
|
|
||||||
|
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
|
||||||
|
batch, _, height, width = x.shape
|
||||||
|
|
||||||
|
is_one_query = self.num_queries == 1
|
||||||
|
|
||||||
|
# queries, keys, values
|
||||||
|
|
||||||
|
q = self.queries * self.scale
|
||||||
|
k, v = self.to_kv(x).chunk(2, dim = 1)
|
||||||
|
|
||||||
|
# similarities
|
||||||
|
|
||||||
|
sim = einsum('h l d, b d x y -> b h l x y', q, k)
|
||||||
|
sim = rearrange(sim, 'b ... x y -> b (...) x y')
|
||||||
|
|
||||||
|
# unfold the similarity scores, with float(-inf) as padding value
|
||||||
|
|
||||||
|
mask_value = -torch.finfo(sim.dtype).max
|
||||||
|
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
|
||||||
|
sim = F.unfold(sim, kernel_size = wsz)
|
||||||
|
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
|
||||||
|
|
||||||
|
# rel pos bias
|
||||||
|
|
||||||
|
sim = sim + self.rel_pos_bias
|
||||||
|
|
||||||
|
# numerically stable attention
|
||||||
|
|
||||||
|
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
|
||||||
|
attn = sim.softmax(dim = -3)
|
||||||
|
|
||||||
|
# unfold values
|
||||||
|
|
||||||
|
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
|
||||||
|
v = F.unfold(v, kernel_size = wsz)
|
||||||
|
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
|
||||||
|
|
||||||
|
# aggregate values
|
||||||
|
|
||||||
|
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
|
||||||
|
|
||||||
|
# combine heads
|
||||||
|
|
||||||
|
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
|
||||||
|
out = self.to_out(out)
|
||||||
|
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
|
||||||
|
|
||||||
|
# return original input if one query
|
||||||
|
|
||||||
|
if is_one_query:
|
||||||
|
out = rearrange(out, 'b 1 ... -> b ...')
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
class QueryAttnUpsample(nn.Module):
|
||||||
|
def __init__(self, dim, **kwargs):
|
||||||
|
super().__init__()
|
||||||
|
self.norm = LayerNormChan(dim)
|
||||||
|
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x = self.norm(x)
|
||||||
|
out = self.qna(x)
|
||||||
|
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
|
||||||
|
return out
|
||||||
@@ -17,6 +17,7 @@ from kornia.filters import gaussian_blur2d
|
|||||||
|
|
||||||
from dalle2_pytorch.tokenizer import tokenizer
|
from dalle2_pytorch.tokenizer import tokenizer
|
||||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||||
|
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||||
|
|
||||||
# use x-clip
|
# use x-clip
|
||||||
|
|
||||||
@@ -692,7 +693,7 @@ class DiffusionPrior(nn.Module):
|
|||||||
# decoder
|
# decoder
|
||||||
|
|
||||||
def Upsample(dim):
|
def Upsample(dim):
|
||||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
return QueryAttnUpsample(dim)
|
||||||
|
|
||||||
def Downsample(dim):
|
def Downsample(dim):
|
||||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
return nn.Conv2d(dim, dim, 4, 2, 1)
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ import torchvision
|
|||||||
|
|
||||||
from einops import rearrange, reduce, repeat
|
from einops import rearrange, reduce, repeat
|
||||||
|
|
||||||
|
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||||
|
|
||||||
# constants
|
# constants
|
||||||
|
|
||||||
MList = nn.ModuleList
|
MList = nn.ModuleList
|
||||||
@@ -243,111 +245,6 @@ class ResBlock(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.net(x) + x
|
return self.net(x) + x
|
||||||
|
|
||||||
# attention-based upsampling
|
|
||||||
# from https://arxiv.org/abs/2112.11435
|
|
||||||
|
|
||||||
class QueryAndAttend(nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
dim,
|
|
||||||
num_queries = 1,
|
|
||||||
dim_head = 32,
|
|
||||||
heads = 8,
|
|
||||||
window_size = 3
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.scale = dim_head ** -0.5
|
|
||||||
inner_dim = dim_head * heads
|
|
||||||
self.heads = heads
|
|
||||||
self.dim_head = dim_head
|
|
||||||
self.window_size = window_size
|
|
||||||
self.num_queries = num_queries
|
|
||||||
|
|
||||||
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
|
|
||||||
|
|
||||||
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
|
|
||||||
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
|
|
||||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
"""
|
|
||||||
einstein notation
|
|
||||||
b - batch
|
|
||||||
h - heads
|
|
||||||
l - num queries
|
|
||||||
d - head dimension
|
|
||||||
x - height
|
|
||||||
y - width
|
|
||||||
j - source sequence for attending to (kernel size squared in this case)
|
|
||||||
"""
|
|
||||||
|
|
||||||
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
|
|
||||||
batch, _, height, width = x.shape
|
|
||||||
|
|
||||||
is_one_query = self.num_queries == 1
|
|
||||||
|
|
||||||
# queries, keys, values
|
|
||||||
|
|
||||||
q = self.queries * self.scale
|
|
||||||
k, v = self.to_kv(x).chunk(2, dim = 1)
|
|
||||||
|
|
||||||
# similarities
|
|
||||||
|
|
||||||
sim = einsum('h l d, b d x y -> b h l x y', q, k)
|
|
||||||
sim = rearrange(sim, 'b ... x y -> b (...) x y')
|
|
||||||
|
|
||||||
# unfold the similarity scores, with float(-inf) as padding value
|
|
||||||
|
|
||||||
mask_value = -torch.finfo(sim.dtype).max
|
|
||||||
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
|
|
||||||
sim = F.unfold(sim, kernel_size = wsz)
|
|
||||||
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
|
|
||||||
|
|
||||||
# rel pos bias
|
|
||||||
|
|
||||||
sim = sim + self.rel_pos_bias
|
|
||||||
|
|
||||||
# numerically stable attention
|
|
||||||
|
|
||||||
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
|
|
||||||
attn = sim.softmax(dim = -3)
|
|
||||||
|
|
||||||
# unfold values
|
|
||||||
|
|
||||||
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
|
|
||||||
v = F.unfold(v, kernel_size = wsz)
|
|
||||||
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
|
|
||||||
|
|
||||||
# aggregate values
|
|
||||||
|
|
||||||
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
|
|
||||||
|
|
||||||
# combine heads
|
|
||||||
|
|
||||||
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
|
|
||||||
out = self.to_out(out)
|
|
||||||
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
|
|
||||||
|
|
||||||
# return original input if one query
|
|
||||||
|
|
||||||
if is_one_query:
|
|
||||||
out = rearrange(out, 'b 1 ... -> b ...')
|
|
||||||
|
|
||||||
return out
|
|
||||||
|
|
||||||
class QueryAttnUpsample(nn.Module):
|
|
||||||
def __init__(self, dim, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.norm = LayerNormChan(dim)
|
|
||||||
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.norm(x)
|
|
||||||
out = self.qna(x)
|
|
||||||
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
|
|
||||||
return out
|
|
||||||
|
|
||||||
# vqgan attention layer
|
# vqgan attention layer
|
||||||
class VQGanAttention(nn.Module):
|
class VQGanAttention(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|||||||
Reference in New Issue
Block a user