mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d25976f33 | ||
|
|
0b28ee0d01 | ||
|
|
45262a4bb7 | ||
|
|
13a58a78c4 | ||
|
|
f75d49c781 |
12
README.md
12
README.md
@@ -523,6 +523,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||
- [ ] spend one day cleaning up tech debt in decoder
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
@@ -531,7 +532,6 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
- [ ] experiment with https://arxiv.org/abs/2112.11435 as upsampler, test in https://github.com/lucidrains/lightweight-gan first
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -577,14 +577,4 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Arar2021LearnedQF,
|
||||
title = {Learned Queries for Efficient Local Attention},
|
||||
author = {Moab Arar and Ariel Shamir and Amit H. Bermano},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2112.11435}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
|
||||
125
dalle2_pytorch/attention.py
Normal file
125
dalle2_pytorch/attention.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
class LayerNormChan(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
eps = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||
|
||||
# attention-based upsampling
|
||||
# from https://arxiv.org/abs/2112.11435
|
||||
|
||||
class QueryAndAttend(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
num_queries = 1,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
window_size = 3
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.window_size = window_size
|
||||
self.num_queries = num_queries
|
||||
|
||||
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
|
||||
|
||||
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
|
||||
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
l - num queries
|
||||
d - head dimension
|
||||
x - height
|
||||
y - width
|
||||
j - source sequence for attending to (kernel size squared in this case)
|
||||
"""
|
||||
|
||||
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
|
||||
batch, _, height, width = x.shape
|
||||
|
||||
is_one_query = self.num_queries == 1
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
q = self.queries * self.scale
|
||||
k, v = self.to_kv(x).chunk(2, dim = 1)
|
||||
|
||||
# similarities
|
||||
|
||||
sim = einsum('h l d, b d x y -> b h l x y', q, k)
|
||||
sim = rearrange(sim, 'b ... x y -> b (...) x y')
|
||||
|
||||
# unfold the similarity scores, with float(-inf) as padding value
|
||||
|
||||
mask_value = -torch.finfo(sim.dtype).max
|
||||
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
|
||||
sim = F.unfold(sim, kernel_size = wsz)
|
||||
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
|
||||
|
||||
# rel pos bias
|
||||
|
||||
sim = sim + self.rel_pos_bias
|
||||
|
||||
# numerically stable attention
|
||||
|
||||
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -3)
|
||||
|
||||
# unfold values
|
||||
|
||||
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
|
||||
v = F.unfold(v, kernel_size = wsz)
|
||||
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
|
||||
|
||||
# combine heads
|
||||
|
||||
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
|
||||
out = self.to_out(out)
|
||||
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
|
||||
|
||||
# return original input if one query
|
||||
|
||||
if is_one_query:
|
||||
out = rearrange(out, 'b 1 ... -> b ...')
|
||||
|
||||
return out
|
||||
|
||||
class QueryAttnUpsample(nn.Module):
|
||||
def __init__(self, dim, **kwargs):
|
||||
super().__init__()
|
||||
self.norm = LayerNormChan(dim)
|
||||
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
out = self.qna(x)
|
||||
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
|
||||
return out
|
||||
@@ -17,6 +17,7 @@ from kornia.filters import gaussian_blur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||
|
||||
# use x-clip
|
||||
|
||||
@@ -1116,7 +1117,7 @@ class Decoder(nn.Module):
|
||||
unet,
|
||||
*,
|
||||
clip,
|
||||
vae = None,
|
||||
vae = tuple(),
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
||||
self.ema_update_every = ema_update_every
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
def update_moving_average(ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
return old * beta + (1 - beta) * new
|
||||
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||
|
||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||
ma_buffer.copy_(new_buffer_value)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
@@ -13,6 +13,8 @@ import torchvision
|
||||
|
||||
from einops import rearrange, reduce, repeat
|
||||
|
||||
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||
|
||||
# constants
|
||||
|
||||
MList = nn.ModuleList
|
||||
@@ -243,111 +245,6 @@ class ResBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
# attention-based upsampling
|
||||
# from https://arxiv.org/abs/2112.11435
|
||||
|
||||
class QueryAndAttend(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
num_queries = 1,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
window_size = 3
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.window_size = window_size
|
||||
self.num_queries = num_queries
|
||||
|
||||
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
|
||||
|
||||
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
|
||||
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
l - num queries
|
||||
d - head dimension
|
||||
x - height
|
||||
y - width
|
||||
j - source sequence for attending to (kernel size squared in this case)
|
||||
"""
|
||||
|
||||
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
|
||||
batch, _, height, width = x.shape
|
||||
|
||||
is_one_query = self.num_queries == 1
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
q = self.queries * self.scale
|
||||
k, v = self.to_kv(x).chunk(2, dim = 1)
|
||||
|
||||
# similarities
|
||||
|
||||
sim = einsum('h l d, b d x y -> b h l x y', q, k)
|
||||
sim = rearrange(sim, 'b ... x y -> b (...) x y')
|
||||
|
||||
# unfold the similarity scores, with float(-inf) as padding value
|
||||
|
||||
mask_value = -torch.finfo(sim.dtype).max
|
||||
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
|
||||
sim = F.unfold(sim, kernel_size = wsz)
|
||||
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
|
||||
|
||||
# rel pos bias
|
||||
|
||||
sim = sim + self.rel_pos_bias
|
||||
|
||||
# numerically stable attention
|
||||
|
||||
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -3)
|
||||
|
||||
# unfold values
|
||||
|
||||
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
|
||||
v = F.unfold(v, kernel_size = wsz)
|
||||
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
|
||||
|
||||
# combine heads
|
||||
|
||||
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
|
||||
out = self.to_out(out)
|
||||
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
|
||||
|
||||
# return original input if one query
|
||||
|
||||
if is_one_query:
|
||||
out = rearrange(out, 'b 1 ... -> b ...')
|
||||
|
||||
return out
|
||||
|
||||
class QueryAttnUpsample(nn.Module):
|
||||
def __init__(self, dim, **kwargs):
|
||||
super().__init__()
|
||||
self.norm = LayerNormChan(dim)
|
||||
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
out = self.qna(x)
|
||||
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
|
||||
return out
|
||||
|
||||
# vqgan attention layer
|
||||
class VQGanAttention(nn.Module):
|
||||
def __init__(
|
||||
@@ -481,7 +378,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(QueryAttnUpsample(dim_out), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
Reference in New Issue
Block a user