mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
19 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e024971dc3 | ||
|
|
4075d02139 | ||
|
|
de0296106b | ||
|
|
eafb136214 | ||
|
|
bfbcc283a3 | ||
|
|
c30544b73a | ||
|
|
bdf5e9c009 | ||
|
|
9878be760b | ||
|
|
7ba6357c05 | ||
|
|
76e063e8b7 | ||
|
|
4d25976f33 | ||
|
|
0b28ee0d01 | ||
|
|
45262a4bb7 | ||
|
|
13a58a78c4 | ||
|
|
f75d49c781 | ||
|
|
3b520dfa85 | ||
|
|
79198c6ae4 | ||
|
|
77a246b1b9 | ||
|
|
f93a3f6ed8 |
159
README.md
159
README.md
@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -246,13 +246,6 @@ loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
|
||||
@@ -355,7 +348,8 @@ decoder = Decoder(
|
||||
image_sizes = (128, 256),
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # set this to True if you wish to condition on text during training and sampling
|
||||
).cuda()
|
||||
|
||||
for unet_number in (1, 2):
|
||||
@@ -383,6 +377,124 @@ You can also train the decoder on images of greater than the size (say 512x512)
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## Training on Preprocessed CLIP Embeddings
|
||||
|
||||
It is likely, when scaling up, that you would first preprocess your images and text into corresponding embeddings before training the prior network. You can do so easily by simply passing in `image_embed`, `text_embed`, and optionally `text_encodings` and `text_mask`
|
||||
|
||||
Working example below
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
|
||||
|
||||
# get trained CLIP from step one
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8,
|
||||
).cuda()
|
||||
|
||||
# setup prior network, which contains an autoregressive transformer
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
# diffusion prior network, which contains the CLIP and network (with transformer) above
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# precompute the text and image embeddings
|
||||
# here using the diffusion prior class, but could be done with CLIP alone
|
||||
|
||||
clip_image_embeds = diffusion_prior.get_image_embed(images)
|
||||
clip_text_embeds = diffusion_prior.get_text_cond(text).get('text_embed')
|
||||
|
||||
# feed text and images into diffusion prior network
|
||||
|
||||
loss = diffusion_prior(
|
||||
text_embed = clip_text_embeds,
|
||||
image_embed = clip_image_embeds
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many steps
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
You can also completely go `CLIP`-less, in which case you will need to pass in the `image_embed_dim` into the `DiffusionPrior` on initialization
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior
|
||||
|
||||
# setup prior network, which contains an autoregressive transformer
|
||||
|
||||
prior_network = DiffusionPriorNetwork(
|
||||
dim = 512,
|
||||
depth = 6,
|
||||
dim_head = 64,
|
||||
heads = 8
|
||||
).cuda()
|
||||
|
||||
# diffusion prior network, which contains the CLIP and network (with transformer) above
|
||||
|
||||
diffusion_prior = DiffusionPrior(
|
||||
net = prior_network,
|
||||
image_embed_dim = 512, # this needs to be set
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2,
|
||||
condition_on_text_encodings = False # this probably should be true, but just to get Laion started
|
||||
).cuda()
|
||||
|
||||
# mock data
|
||||
|
||||
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||
images = torch.randn(4, 3, 256, 256).cuda()
|
||||
|
||||
# precompute the text and image embeddings
|
||||
# here using the diffusion prior class, but could be done with CLIP alone
|
||||
|
||||
clip_image_embeds = torch.randn(4, 512).cuda()
|
||||
clip_text_embeds = torch.randn(4, 512).cuda()
|
||||
|
||||
# feed text and images into diffusion prior network
|
||||
|
||||
loss = diffusion_prior(
|
||||
text_embed = clip_text_embeds,
|
||||
image_embed = clip_image_embeds
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
|
||||
# do the above for many many many steps
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
@@ -530,13 +642,16 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
- [ ] spend one day cleaning up tech debt in decoder
|
||||
- [x] use attention-based upsampling https://arxiv.org/abs/2112.11435
|
||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -568,7 +683,7 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Liu2022ACF,
|
||||
title = {A ConvNet for the 2020s},
|
||||
title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
|
||||
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
|
||||
year = {2022}
|
||||
}
|
||||
@@ -582,4 +697,24 @@ Once built, images will be saved to the same directory the command is invoked
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Arar2021LearnedQF,
|
||||
title = {Learned Queries for Efficient Local Attention},
|
||||
author = {Moab Arar and Ariel Shamir and Amit H. Bermano},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2112.11435}
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@article{Yu2021VectorquantizedIM,
|
||||
title = {Vector-quantized Image Modeling with Improved VQGAN},
|
||||
author = {Jiahui Yu and Xin Li and Jing Yu Koh and Han Zhang and Ruoming Pang and James Qin and Alexander Ku and Yuanzhong Xu and Jason Baldridge and Yonghui Wu},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2110.04627}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
|
||||
130
dalle2_pytorch/attention.py
Normal file
130
dalle2_pytorch/attention.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
class LayerNormChan(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
eps = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||
|
||||
# attention-based upsampling
|
||||
# from https://arxiv.org/abs/2112.11435
|
||||
|
||||
class QueryAndAttend(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
num_queries = 1,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
window_size = 3
|
||||
):
|
||||
super().__init__()
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
self.window_size = window_size
|
||||
self.num_queries = num_queries
|
||||
|
||||
self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
|
||||
|
||||
self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
|
||||
self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Conv2d(inner_dim, dim * 2, 1, bias = False),
|
||||
nn.Tanh(),
|
||||
nn.Conv2d(dim * 2, dim, 1, bias = False)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
einstein notation
|
||||
b - batch
|
||||
h - heads
|
||||
l - num queries
|
||||
d - head dimension
|
||||
x - height
|
||||
y - width
|
||||
j - source sequence for attending to (kernel size squared in this case)
|
||||
"""
|
||||
|
||||
wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
|
||||
batch, _, height, width = x.shape
|
||||
|
||||
is_one_query = self.num_queries == 1
|
||||
|
||||
# queries, keys, values
|
||||
|
||||
q = self.queries * self.scale
|
||||
k, v = self.to_kv(x).chunk(2, dim = 1)
|
||||
|
||||
# similarities
|
||||
|
||||
sim = einsum('h l d, b d x y -> b h l x y', q, k)
|
||||
sim = rearrange(sim, 'b ... x y -> b (...) x y')
|
||||
|
||||
# unfold the similarity scores, with float(-inf) as padding value
|
||||
|
||||
mask_value = -torch.finfo(sim.dtype).max
|
||||
sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
|
||||
sim = F.unfold(sim, kernel_size = wsz)
|
||||
sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
|
||||
|
||||
# rel pos bias
|
||||
|
||||
sim = sim + self.rel_pos_bias
|
||||
|
||||
# numerically stable attention
|
||||
|
||||
sim = sim - sim.amax(dim = -3, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -3)
|
||||
|
||||
# unfold values
|
||||
|
||||
v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
|
||||
v = F.unfold(v, kernel_size = wsz)
|
||||
v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
|
||||
|
||||
# aggregate values
|
||||
|
||||
out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
|
||||
|
||||
# combine heads
|
||||
|
||||
out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
|
||||
out = self.to_out(out)
|
||||
out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
|
||||
|
||||
# return original input if one query
|
||||
|
||||
if is_one_query:
|
||||
out = rearrange(out, 'b 1 ... -> b ...')
|
||||
|
||||
return out
|
||||
|
||||
class QueryAttnUpsample(nn.Module):
|
||||
def __init__(self, dim, **kwargs):
|
||||
super().__init__()
|
||||
self.norm = LayerNormChan(dim)
|
||||
self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
out = self.qna(x)
|
||||
out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
|
||||
return out
|
||||
@@ -17,6 +17,7 @@ from kornia.filters import gaussian_blur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||
|
||||
# use x-clip
|
||||
|
||||
@@ -83,7 +84,7 @@ def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://
|
||||
if orig_image_size == shape:
|
||||
return t
|
||||
|
||||
return F.interpolate(t, size = shape, mode = mode)
|
||||
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
|
||||
|
||||
# classifier free guidance functions
|
||||
|
||||
@@ -142,6 +143,92 @@ def sigmoid_beta_schedule(timesteps):
|
||||
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
|
||||
|
||||
|
||||
class BaseGaussianDiffusion(nn.Module):
|
||||
def __init__(self, *, beta_schedule, timesteps, loss_type):
|
||||
super().__init__()
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
elif beta_schedule == "linear":
|
||||
betas = linear_beta_schedule(timesteps)
|
||||
elif beta_schedule == "quadratic":
|
||||
betas = quadratic_beta_schedule(timesteps)
|
||||
elif beta_schedule == "jsd":
|
||||
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
|
||||
elif beta_schedule == "sigmoid":
|
||||
betas = sigmoid_beta_schedule(timesteps)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.loss_type = loss_type
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
return (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
||||
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
)
|
||||
|
||||
def sample(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
# diffusion prior
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
@@ -420,25 +507,41 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
image_embed,
|
||||
diffusion_timesteps,
|
||||
*,
|
||||
text_encodings,
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
mask = None,
|
||||
cond_drop_prob = 0.2
|
||||
):
|
||||
batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
|
||||
|
||||
# make text encodings optional
|
||||
# although the paper seems to suggest it is present <--
|
||||
|
||||
if not exists(text_encodings):
|
||||
text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
|
||||
|
||||
if not exists(mask):
|
||||
mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
|
||||
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
|
||||
|
||||
mask &= cond_prob_mask
|
||||
|
||||
# whether text embedding is masked or not depends on the classifier free guidance conditional masking
|
||||
|
||||
mask = torch.cat((mask, cond_prob_mask), dim = 1)
|
||||
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
if exists(mask):
|
||||
not_all_masked_out = mask.any(dim = -1)
|
||||
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
|
||||
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
@@ -454,16 +557,6 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
learned_queries
|
||||
), dim = -2)
|
||||
|
||||
# mask if it doesn't exist
|
||||
|
||||
if not exists(mask):
|
||||
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
|
||||
mask &= rearrange(cond_prob_mask, 'b -> b 1')
|
||||
|
||||
# attend
|
||||
|
||||
tokens = self.causal_transformer(tokens, mask = mask)
|
||||
@@ -474,81 +567,50 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
|
||||
return pred_image_embed
|
||||
|
||||
class DiffusionPrior(nn.Module):
|
||||
class DiffusionPrior(BaseGaussianDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
net,
|
||||
*,
|
||||
clip,
|
||||
clip = None,
|
||||
image_embed_dim = None,
|
||||
image_size = None,
|
||||
image_channels = 3,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = "l1",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
self.clip = clip
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type
|
||||
)
|
||||
|
||||
if exists(clip):
|
||||
assert isinstance(clip, CLIP)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
self.clip = clip
|
||||
else:
|
||||
assert exists(image_embed_dim), 'latent dimension must be given, if training prior network without CLIP given'
|
||||
self.clip = None
|
||||
|
||||
self.net = net
|
||||
self.image_embed_dim = clip.dim_latent
|
||||
self.channels = clip.image_channels
|
||||
self.image_size = clip.image_size
|
||||
self.image_embed_dim = default(image_embed_dim, lambda: clip.dim_latent)
|
||||
self.channels = default(image_channels, lambda: clip.image_channels)
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
elif beta_schedule == "linear":
|
||||
betas = linear_beta_schedule(timesteps)
|
||||
elif beta_schedule == "quadratic":
|
||||
betas = quadratic_beta_schedule(timesteps)
|
||||
elif beta_schedule == "jsd":
|
||||
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
|
||||
elif beta_schedule == "sigmoid":
|
||||
betas = sigmoid_beta_schedule(timesteps)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.loss_type = loss_type
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
assert exists(self.clip)
|
||||
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
@@ -556,33 +618,18 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_cond(self, text):
|
||||
assert exists(self.clip)
|
||||
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
text_embed = self.clip.to_text_latent(text_cls)
|
||||
text_embed = l2norm(text_embed)
|
||||
|
||||
if not self.condition_on_text_encodings:
|
||||
return dict(text_embed = text_embed)
|
||||
|
||||
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
||||
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
pred = self.net(x, t, **text_cond)
|
||||
|
||||
@@ -619,14 +666,6 @@ class DiffusionPrior(nn.Module):
|
||||
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
return (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def p_losses(self, image_embed, t, text_cond, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(image_embed))
|
||||
|
||||
@@ -678,13 +717,41 @@ class DiffusionPrior(nn.Module):
|
||||
top_image_embeds = image_embeds.gather(1, top_sim_indices)
|
||||
return rearrange(top_image_embeds, 'b 1 d -> b d')
|
||||
|
||||
def forward(self, text, image, *args, **kwargs):
|
||||
b, device, img_size, = image.shape[0], image.device, self.image_size
|
||||
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
|
||||
def forward(
|
||||
self,
|
||||
text = None,
|
||||
image = None,
|
||||
text_embed = None, # allow for training on preprocessed CLIP text and image embeddings
|
||||
image_embed = None,
|
||||
text_encodings = None, # as well as CLIP text encodings
|
||||
text_mask = None, # text mask <- may eventually opt for the learned padding tokens technique from DALL-E1 to reduce complexity
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
|
||||
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
|
||||
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
image_embed = self.get_image_embed(image)
|
||||
text_cond = self.get_text_cond(text)
|
||||
if exists(image):
|
||||
image_embed = self.get_image_embed(image)
|
||||
|
||||
# calculate text conditionings, based on what is passed in
|
||||
|
||||
if exists(text):
|
||||
text_cond = self.get_text_cond(text)
|
||||
else:
|
||||
text_cond = dict(
|
||||
text_embed = text_embed,
|
||||
text_encodings = text_encodings,
|
||||
mask = text_mask
|
||||
)
|
||||
|
||||
# timestep conditioning from ddpm
|
||||
|
||||
batch, device = image_embed.shape[0], image_embed.device
|
||||
times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
|
||||
|
||||
# calculate forward loss
|
||||
|
||||
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
|
||||
return loss
|
||||
@@ -839,6 +906,7 @@ class Unet(nn.Module):
|
||||
dim,
|
||||
*,
|
||||
image_embed_dim,
|
||||
text_embed_dim = None,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
num_time_tokens = 2,
|
||||
@@ -852,6 +920,7 @@ class Unet(nn.Module):
|
||||
sparse_attn_window = 8, # window size for sparse attention
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
cond_on_text_encodings = False,
|
||||
max_text_len = 256,
|
||||
cond_on_image_embeds = False,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -891,7 +960,7 @@ class Unet(nn.Module):
|
||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||
) if image_embed_dim != cond_dim else nn.Identity()
|
||||
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim)
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
|
||||
|
||||
# finer control over whether to condition on image embeddings and text encodings
|
||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||
@@ -902,7 +971,7 @@ class Unet(nn.Module):
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||
|
||||
# attention related params
|
||||
|
||||
@@ -1030,7 +1099,7 @@ class Unet(nn.Module):
|
||||
text_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
text_tokens,
|
||||
self.null_text_embed
|
||||
self.null_text_embed[:, :text_tokens.shape[1]]
|
||||
)
|
||||
|
||||
# main conditioning tokens (c)
|
||||
@@ -1110,13 +1179,13 @@ class LowresConditioner(nn.Module):
|
||||
|
||||
return cond_fmap
|
||||
|
||||
class Decoder(nn.Module):
|
||||
class Decoder(BaseGaussianDiffusion):
|
||||
def __init__(
|
||||
self,
|
||||
unet,
|
||||
*,
|
||||
clip,
|
||||
vae = None,
|
||||
vae = tuple(),
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
@@ -1128,14 +1197,22 @@ class Decoder(nn.Module):
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
beta_schedule = beta_schedule,
|
||||
timesteps = timesteps,
|
||||
loss_type = loss_type
|
||||
)
|
||||
|
||||
assert isinstance(clip, CLIP)
|
||||
freeze_model_and_make_eval_(clip)
|
||||
self.clip = clip
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
|
||||
self.condition_on_text_encodings = condition_on_text_encodings
|
||||
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
@@ -1191,55 +1268,6 @@ class Decoder(nn.Module):
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
# noise schedule
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
elif beta_schedule == "linear":
|
||||
betas = linear_beta_schedule(timesteps)
|
||||
elif beta_schedule == "quadratic":
|
||||
betas = quadratic_beta_schedule(timesteps)
|
||||
elif beta_schedule == "jsd":
|
||||
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
|
||||
elif beta_schedule == "sigmoid":
|
||||
betas = sigmoid_beta_schedule(timesteps)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.loss_type = loss_type
|
||||
|
||||
self.register_buffer('betas', betas)
|
||||
self.register_buffer('alphas_cumprod', alphas_cumprod)
|
||||
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
self.register_buffer('posterior_variance', posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
@@ -1272,27 +1300,6 @@ class Decoder(nn.Module):
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
|
||||
def q_mean_variance(self, x_start, t):
|
||||
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
||||
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
||||
return mean, variance, log_variance
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
||||
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
||||
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
@@ -1337,14 +1344,6 @@ class Decoder(nn.Module):
|
||||
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
return (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
@@ -1379,6 +1378,8 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
|
||||
img = None
|
||||
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
@@ -1439,6 +1440,8 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
@@ -1466,7 +1469,9 @@ class DALLE2(nn.Module):
|
||||
assert isinstance(decoder, Decoder)
|
||||
self.prior = prior
|
||||
self.decoder = decoder
|
||||
|
||||
self.prior_num_samples = prior_num_samples
|
||||
self.decoder_need_text_cond = self.decoder.condition_on_text_encodings
|
||||
|
||||
@torch.no_grad()
|
||||
@eval_decorator
|
||||
@@ -1483,7 +1488,9 @@ class DALLE2(nn.Module):
|
||||
text = tokenizer.tokenize(text).to(device)
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
|
||||
text_cond = text if self.decoder_need_text_cond else None
|
||||
images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
import copy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
# exponential moving average wrapper
|
||||
|
||||
class EMA(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
beta = 0.99,
|
||||
ema_update_after_step = 1000,
|
||||
ema_update_every = 10,
|
||||
):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.online_model = model
|
||||
self.ema_model = copy.deepcopy(model)
|
||||
|
||||
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
||||
self.ema_update_every = ema_update_every
|
||||
|
||||
self.register_buffer('initted', torch.Tensor([False]))
|
||||
self.register_buffer('step', torch.tensor([0.]))
|
||||
|
||||
def update(self):
|
||||
self.step += 1
|
||||
|
||||
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
||||
return
|
||||
|
||||
if not self.initted:
|
||||
self.ema_model.state_dict(self.online_model.state_dict())
|
||||
self.initted.data.copy_(torch.Tensor([True]))
|
||||
|
||||
self.update_moving_average(self.ema_model, self.online_model)
|
||||
|
||||
def update_moving_average(ma_model, current_model):
|
||||
def calculate_ema(beta, old, new):
|
||||
if not exists(old):
|
||||
return new
|
||||
return old * beta + (1 - beta) * new
|
||||
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = calculate_ema(self.beta, old_weight, up_weight)
|
||||
|
||||
for current_buffer, ma_buffer in zip(current_model.buffers(), ma_model.buffers()):
|
||||
new_buffer_value = calculate_ema(self.beta, ma_buffer, current_buffer)
|
||||
ma_buffer.copy_(new_buffer_value)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.ema_model(*args, **kwargs)
|
||||
|
||||
@@ -12,6 +12,10 @@ from torch.autograd import grad as torch_grad
|
||||
import torchvision
|
||||
|
||||
from einops import rearrange, reduce, repeat
|
||||
from einops_exts import rearrange_many
|
||||
from einops.layers.torch import Rearrange
|
||||
|
||||
from dalle2_pytorch.attention import QueryAttnUpsample
|
||||
|
||||
# constants
|
||||
|
||||
@@ -144,6 +148,8 @@ class LayerNormChan(nn.Module):
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||
|
||||
# discriminator
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -177,6 +183,8 @@ class Discriminator(nn.Module):
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
# positional encoding
|
||||
|
||||
class ContinuousPositionBias(nn.Module):
|
||||
""" from https://arxiv.org/abs/2111.09883 """
|
||||
|
||||
@@ -211,6 +219,84 @@ class ContinuousPositionBias(nn.Module):
|
||||
bias = rearrange(rel_pos, 'i j h -> h i j')
|
||||
return x + bias
|
||||
|
||||
# resnet encoder / decoder
|
||||
|
||||
class ResnetEncDec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
num_resnet_blocks = 1,
|
||||
resnet_groups = 16,
|
||||
first_conv_kernel_size = 5,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
attn_dropout = 0.,
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||
|
||||
self.layers = layers
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_resnet_blocks):
|
||||
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
def encode(self, x):
|
||||
for enc in self.encoders:
|
||||
x = enc(x)
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
for dec in self.decoders:
|
||||
x = dec(x)
|
||||
return x
|
||||
|
||||
class GLUResBlock(nn.Module):
|
||||
def __init__(self, chan, groups = 16):
|
||||
super().__init__()
|
||||
@@ -243,6 +329,8 @@ class ResBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
# vqgan attention layer
|
||||
|
||||
class VQGanAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -287,6 +375,147 @@ class VQGanAttention(nn.Module):
|
||||
|
||||
return out + residual
|
||||
|
||||
# ViT encoder / decoder
|
||||
|
||||
class RearrangeImage(nn.Module):
|
||||
def forward(self, x):
|
||||
n = x.shape[1]
|
||||
w = h = int(sqrt(n))
|
||||
return rearrange(x, 'b (h w) ... -> b h w ...', h = h, w = w)
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
heads = 8,
|
||||
dim_head = 32
|
||||
):
|
||||
super().__init__()
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
|
||||
self.to_out = nn.Linear(inner_dim, dim)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.heads
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = -1)
|
||||
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
|
||||
|
||||
q = q * self.scale
|
||||
sim = einsum('b h i d, b h j d -> b h i j', q, k)
|
||||
|
||||
sim = sim - sim.amax(dim = -1, keepdim = True).detach()
|
||||
attn = sim.softmax(dim = -1)
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
def FeedForward(dim, mult = 4):
|
||||
return nn.Sequential(
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, dim * mult, bias = False),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * mult, dim, bias = False)
|
||||
)
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
*,
|
||||
layers,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
ff_mult = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(layers):
|
||||
self.layers.append(nn.ModuleList([
|
||||
Attention(dim = dim, dim_head = dim_head, heads = heads),
|
||||
FeedForward(dim = dim, mult = ff_mult)
|
||||
]))
|
||||
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
def forward(self, x):
|
||||
for attn, ff in self.layers:
|
||||
x = attn(x) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
|
||||
class ViTEncDec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
patch_size = 8,
|
||||
dim_head = 32,
|
||||
heads = 8,
|
||||
ff_mult = 4
|
||||
):
|
||||
super().__init__()
|
||||
self.encoded_dim = dim
|
||||
self.patch_size = patch_size
|
||||
|
||||
input_dim = channels * (patch_size ** 2)
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
|
||||
nn.Linear(input_dim, dim),
|
||||
Transformer(
|
||||
dim = dim,
|
||||
dim_head = dim_head,
|
||||
heads = heads,
|
||||
ff_mult = ff_mult,
|
||||
layers = layers
|
||||
),
|
||||
RearrangeImage(),
|
||||
Rearrange('b h w c -> b c h w')
|
||||
)
|
||||
|
||||
self.decoder = nn.Sequential(
|
||||
Rearrange('b c h w -> b (h w) c'),
|
||||
Transformer(
|
||||
dim = dim,
|
||||
dim_head = dim_head,
|
||||
heads = heads,
|
||||
ff_mult = ff_mult,
|
||||
layers = layers
|
||||
),
|
||||
nn.Sequential(
|
||||
nn.Linear(dim, dim * 2, bias = False),
|
||||
nn.Tanh(),
|
||||
nn.Linear(dim * 2, dim, bias = False),
|
||||
),
|
||||
nn.LayerNorm(dim),
|
||||
nn.Linear(dim, input_dim),
|
||||
RearrangeImage(),
|
||||
Rearrange('b h w (p1 p2 c) -> b c (h p1) (w p2)', p1 = patch_size, p2 = patch_size)
|
||||
)
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // self.patch_size
|
||||
|
||||
def encode(self, x):
|
||||
return self.encoder(x)
|
||||
|
||||
def decode(self, x):
|
||||
return self.decoder(x)
|
||||
|
||||
# main vqgan-vae classes
|
||||
|
||||
class NullVQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -317,81 +546,43 @@ class VQGanVAE(nn.Module):
|
||||
image_size,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
num_resnet_blocks = 1,
|
||||
vgg = None,
|
||||
vq_codebook_size = 512,
|
||||
vq_decay = 0.8,
|
||||
vq_commitment_weight = 1.,
|
||||
vq_kmeans_init = True,
|
||||
vq_use_cosine_sim = True,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
resnet_groups = 16,
|
||||
attn_dropout = 0.,
|
||||
first_conv_kernel_size = 5,
|
||||
use_vgg_and_gan = True,
|
||||
vae_type = 'resnet',
|
||||
discr_layers = 4,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||
|
||||
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
|
||||
encdec_kwargs, kwargs = groupby_prefix_and_trim('encdec_', kwargs)
|
||||
|
||||
self.image_size = image_size
|
||||
self.channels = channels
|
||||
self.layers = layers
|
||||
self.fmap_size = image_size // (layers ** 2)
|
||||
self.codebook_size = vq_codebook_size
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
if vae_type == 'resnet':
|
||||
enc_dec_klass = ResnetEncDec
|
||||
elif vae_type == 'vit':
|
||||
enc_dec_klass = ViTEncDec
|
||||
else:
|
||||
raise ValueError(f'{vae_type} not valid')
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
codebook_dim = layer_dims[-1]
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_resnet_blocks):
|
||||
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
self.enc_dec = enc_dec_klass(
|
||||
dim = dim,
|
||||
channels = channels,
|
||||
layers = layers,
|
||||
**encdec_kwargs
|
||||
)
|
||||
|
||||
self.vq = VQ(
|
||||
dim = codebook_dim,
|
||||
dim = self.enc_dec.encoded_dim,
|
||||
codebook_size = vq_codebook_size,
|
||||
decay = vq_decay,
|
||||
commitment_weight = vq_commitment_weight,
|
||||
@@ -424,13 +615,21 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
# gan related losses
|
||||
|
||||
layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
|
||||
self.discr = Discriminator(dims = dims, channels = channels)
|
||||
|
||||
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||
|
||||
@property
|
||||
def encoded_dim(self):
|
||||
return self.enc_dec.encoded_dim
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
return self.enc_dec.get_encoded_fmap_size(image_size)
|
||||
|
||||
def copy_for_eval(self):
|
||||
device = next(self.parameters()).device
|
||||
@@ -456,16 +655,13 @@ class VQGanVAE(nn.Module):
|
||||
return self.vq.codebook
|
||||
|
||||
def encode(self, fmap):
|
||||
for enc in self.encoders:
|
||||
fmap = enc(fmap)
|
||||
|
||||
fmap = self.enc_dec.encode(fmap)
|
||||
return fmap
|
||||
|
||||
def decode(self, fmap, return_indices_and_loss = False):
|
||||
fmap, indices, commit_loss = self.vq(fmap)
|
||||
|
||||
for dec in self.decoders:
|
||||
fmap = dec(fmap)
|
||||
fmap = self.enc_dec.decode(fmap)
|
||||
|
||||
if not return_indices_and_loss:
|
||||
return fmap
|
||||
|
||||
Reference in New Issue
Block a user