Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-12 19:44:26 +01:00
Compare commits
27 Commits
8f2a0c7e00, 863f4ef243, fb8a66a2de, 579d4b42dd, 473808850a, d5318aef4f, f82917e1fd, 05b74be69a, a8b5d5d753, 976ef7f87c, fd175bcc0e, 76b32f18b3, f2d5b87677, 461347c171, 46cef31c86, 59b1a77d4d, 7f338319fd, 2c6c91829d, ad17c69ab6, 0b4ec34efb, f027b82e38, 8cc9016cb0, 1d8f37befe, faebf4c8b8, b8e8d3c164, 8e2416b49b, f37c26e856
README.md (149)
@@ -1,6 +1,6 @@
<img src="./dalle2.png" width="450px"></img>

## DALL-E 2 - Pytorch (wip)
## DALL-E 2 - Pytorch

Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.

@@ -10,8 +10,6 @@ The main novelty seems to be an extra layer of indirection with the prior networ

This model is SOTA for text-to-image for now.

It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.

Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication

There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
@@ -385,7 +383,127 @@ You can also train the decoder on images of greater than the size (say 512x512)

For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.

## CLI Usage (work in progress)
## Experimental

### DALL-E2 with Latent Diffusion

This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.

You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE

# trained clip from step 1

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

# 3 unets for the decoder (a la cascading DDPM)

# first two unets are doing latent diffusion
# vqgan-vae must be trained beforehand

vae1 = VQGanVAE(
    dim = 32,
    image_size = 256,
    layers = 3,
    layer_mults = (1, 2, 4)
)

vae2 = VQGanVAE(
    dim = 32,
    image_size = 512,
    layers = 3,
    layer_mults = (1, 2, 4)
)

unet1 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    sparse_attn = True,
    sparse_attn_window = 2,
    dim_mults = (1, 2, 4, 8)
)

unet2 = Unet(
    dim = 32,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_image_embeds = True,
    cond_on_text_encodings = False
)

unet3 = Unet(
    dim = 32,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_image_embeds = True,
    cond_on_text_encodings = False,
    attend_at_middle = False
)

# decoder, which contains the unet(s) and clip

decoder = Decoder(
    clip = clip,
    vae = (vae1, vae2),              # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
    unet = (unet1, unet2, unet3),    # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
    image_sizes = (256, 512, 1024),  # resolutions, 256 for first unet, 512 for second, 1024 for third
    timesteps = 100,
    cond_drop_prob = 0.2
).cuda()

# mock images (get a lot of this)

images = torch.randn(1, 3, 1024, 1024).cuda()

# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme

with decoder.one_unet_in_gpu(1):
    loss = decoder(images, unet_number = 1)
    loss.backward()

with decoder.one_unet_in_gpu(2):
    loss = decoder(images, unet_number = 2)
    loss.backward()

with decoder.one_unet_in_gpu(3):
    loss = decoder(images, unet_number = 3)
    loss.backward()

# do the above for many steps for each of the unets

# then it will learn to generate images based on the CLIP image embeddings

# chaining the unets from lowest resolution to highest resolution (thus cascading)

mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
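
The example above assumes the VQGAN-VAEs are already trained; the training tooling is still on the todo list below, but the `VQGanVAE` added in this diff already returns an autoencoder loss and a discriminator loss from its forward pass. The loop below is only a rough sketch of how that pretraining could look, with the optimizers, learning rate, batch and step counts as placeholder assumptions rather than anything the repository ships.

```python
import torch
from dalle2_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 32,
    image_size = 256,
    layers = 3,
    layer_mults = (1, 2, 4)
).cuda()

# hypothetical optimizers and data - the repo does not ship a VQGAN trainer yet
# a real trainer would exclude the discriminator and VGG parameters from the autoencoder optimizer
opt_vae   = torch.optim.Adam(vae.parameters(), lr = 3e-4)
opt_discr = torch.optim.Adam(vae.discr.parameters(), lr = 3e-4)

for _ in range(100):                                # placeholder number of steps
    images = torch.randn(4, 3, 256, 256).cuda()     # mock batch, substitute a real dataloader

    # autoencoder step: reconstruction + perceptual + commitment + adaptively weighted adversarial loss
    loss = vae(images, return_loss = True)
    loss.backward()
    opt_vae.step()
    opt_vae.zero_grad()

    # discriminator step: hinge (or bce) loss with gradient penalty
    discr_loss = vae(images, return_discr_loss = True)
    discr_loss.backward()
    opt_discr.step()
    opt_discr.zero_grad()
```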

## Training wrapper (wip)

Offer training wrappers

## CLI (wip)

```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'

@@ -393,9 +511,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'

Once built, images will be saved to the same directory the command is invoked from

## Training wrapper (wip)

Offer training wrappers
<a href="https://github.com/lucidrains/big-sleep">template</a>

## Training CLI (wip)
@@ -412,9 +528,15 @@ Offer training wrappers

- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unet separately)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [ ] spend one day cleaning up tech debt in decoder
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
## Citations

@@ -452,17 +574,6 @@ Offer training wrappers

}
```

```bibtex
@misc{zhang2019root,
    title = {Root Mean Square Layer Normalization},
    author = {Biao Zhang and Rico Sennrich},
    year = {2019},
    eprint = {1910.07467},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```

```bibtex
@inproceedings{Tu2022MaxViTMV,
    title = {MaxViT: Multi-Axis Vision Transformer},
dalle2_pytorch/__init__.py

@@ -1,2 +1,4 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder

from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP
dalle2_pytorch/cli.py

@@ -1,9 +1,51 @@
import click
import torch
import torchvision.transforms as T
from pathlib import Path
from functools import reduce  # needed by safeget below

from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior

def safeget(dictionary, keys, default = None):
    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)

def simple_slugify(text, max_length = 255):
    return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]

def get_pkg_version():
    from pkg_resources import get_distribution
    return get_distribution('dalle2_pytorch').version

def main():
    pass

@click.command()
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text')
def dream(text):
    return 'not ready yet'
def dream(
    model,
    cond_scale,
    text
):
    model_path = Path(model)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'model not found at {full_model_path}'
    loaded = torch.load(str(model_path))

    version = safeget(loaded, 'version')
    print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')

    prior_init_params = safeget(loaded, 'init_params.prior')
    decoder_init_params = safeget(loaded, 'init_params.decoder')
    model_params = safeget(loaded, 'model_params')

    prior = DiffusionPrior(**prior_init_params)
    decoder = Decoder(**decoder_init_params)

    dalle2 = DALLE2(prior, decoder)
    dalle2.load_state_dict(model_params)

    image = dalle2(text, cond_scale = cond_scale)

    pil_image = T.ToPILImage()(image)
    return pil_image.save(f'./{simple_slugify(text)}.png')
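
For reference, `dream` expects the checkpoint at `--model` to carry a version string, the constructor kwargs for the prior and decoder, and the combined state dict, since those are the keys it pulls out with `safeget`. The save side is not part of this diff, so the layout below is an inferred sketch rather than an API the package provides:

```python
import torch

# hypothetical checkpoint layout, mirrored from the keys cli.dream reads:
# 'version', 'init_params.prior', 'init_params.decoder', 'model_params'
prior_kwargs   = dict()  # the kwargs DiffusionPrior(...) was constructed with
decoder_kwargs = dict()  # the kwargs Decoder(...) was constructed with
model_params   = dict()  # dalle2.state_dict() from a trained DALLE2

torch.save({
    'version': '0.0.43',
    'init_params': {
        'prior': prior_kwargs,
        'decoder': decoder_kwargs
    },
    'model_params': model_params
}, './dalle2.pt')
```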
dalle2_pytorch/dalle2_pytorch.py

@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
|
||||
from kornia.filters import gaussian_blur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
|
||||
# use x-clip
|
||||
|
||||
@@ -48,6 +49,12 @@ def is_list_str(x):
|
||||
return False
|
||||
return all([type(el) == str for el in x])
|
||||
|
||||
def pad_tuple_to_length(t, length, fillvalue = None):
|
||||
remain_length = length - len(t)
|
||||
if remain_length <= 0:
|
||||
return t
|
||||
return (*t, *((fillvalue,) * remain_length))
|
||||
|
||||
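
`pad_tuple_to_length` is what lets the README example above hand the `Decoder` only two VAEs for three unets; the missing slots get filled with a `NullVQGanVAE`. A tiny self-contained illustration, with strings standing in for the actual modules:

```python
def pad_tuple_to_length(t, length, fillvalue = None):
    # identical to the helper added in this diff
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

print(pad_tuple_to_length(('vae1', 'vae2'), 3, fillvalue = 'null-vae'))
# ('vae1', 'vae2', 'null-vae')
```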
# for controlling freezing of CLIP
|
||||
|
||||
def set_module_requires_grad_(module, requires_grad):
|
||||
@@ -137,23 +144,27 @@ def sigmoid_beta_schedule(timesteps):
|
||||
|
||||
# diffusion prior
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.register_buffer("beta", torch.zeros(dim))
|
||||
|
||||
def forward(self, x):
|
||||
return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
|
||||
|
||||
|
||||
class ChanLayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(dim))
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
|
||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||
return x * inv_norm * self.gamma * self.scale
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g
|
||||
|
||||
class ChanRMSNorm(RMSNorm):
|
||||
def forward(self, x):
|
||||
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
|
||||
inv_norm = torch.rsqrt(squared_sum + self.eps)
|
||||
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
@@ -249,10 +260,10 @@ def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
|
||||
|
||||
inner_dim = int(mult * dim)
|
||||
return nn.Sequential(
|
||||
RMSNorm(dim),
|
||||
LayerNorm(dim),
|
||||
nn.Linear(dim, inner_dim * 2, bias = False),
|
||||
SwiGLU(),
|
||||
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
LayerNorm(inner_dim) if post_activation_norm else nn.Identity(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim, bias = False)
|
||||
)
|
||||
@@ -275,7 +286,8 @@ class Attention(nn.Module):
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.causal = causal
|
||||
self.norm = RMSNorm(dim)
|
||||
self.norm = LayerNorm(dim)
|
||||
self.post_norm = LayerNorm(dim) # sandwich norm from CogView paper + Normformer
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -331,7 +343,8 @@ class Attention(nn.Module):
|
||||
out = einsum('b h i j, b j d -> b h i d', attn, v)
|
||||
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
out = self.to_out(out)
|
||||
return self.post_norm(out)
|
||||
|
||||
class CausalTransformer(nn.Module):
|
||||
def __init__(
|
||||
@@ -344,7 +357,8 @@ class CausalTransformer(nn.Module):
|
||||
ff_mult = 4,
|
||||
norm_out = False,
|
||||
attn_dropout = 0.,
|
||||
ff_dropout = 0.
|
||||
ff_dropout = 0.,
|
||||
final_proj = True
|
||||
):
|
||||
super().__init__()
|
||||
self.rel_pos_bias = RelPosBias(heads = heads)
|
||||
@@ -356,7 +370,8 @@ class CausalTransformer(nn.Module):
|
||||
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
|
||||
]))
|
||||
|
||||
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.norm = LayerNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
|
||||
self.project_out = nn.Linear(dim, dim, bias = False) if final_proj else nn.Identity()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -371,7 +386,8 @@ class CausalTransformer(nn.Module):
|
||||
x = attn(x, mask = mask, attn_bias = attn_bias) + x
|
||||
x = ff(x) + x
|
||||
|
||||
return self.norm(x)
|
||||
out = self.norm(x)
|
||||
return self.project_out(out)
|
||||
|
||||
class DiffusionPriorNetwork(nn.Module):
|
||||
def __init__(
|
||||
@@ -464,11 +480,11 @@ class DiffusionPrior(nn.Module):
|
||||
net,
|
||||
*,
|
||||
clip,
|
||||
timesteps=1000,
|
||||
cond_drop_prob=0.2,
|
||||
loss_type="l1",
|
||||
predict_x0=True,
|
||||
beta_schedule="cosine",
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = "l1",
|
||||
predict_x_start = True,
|
||||
beta_schedule = "cosine",
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -481,7 +497,7 @@ class DiffusionPrior(nn.Module):
|
||||
self.image_size = clip.image_size
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
self.predict_x0 = predict_x0
|
||||
self.predict_x_start = predict_x_start
|
||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
@@ -498,7 +514,7 @@ class DiffusionPrior(nn.Module):
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
@@ -531,12 +547,14 @@ class DiffusionPrior(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_cond(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
@@ -566,14 +584,16 @@ class DiffusionPrior(nn.Module):
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||
if self.predict_x0:
|
||||
x_recon = self.net(x, t, **text_cond)
|
||||
pred = self.net(x, t, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_recon = pred
|
||||
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
|
||||
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised:
|
||||
if clip_denoised and not self.predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
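
A quick aside on the `predict_x_start` switch above: `predict_start_from_noise` is just the standard DDPM inversion of the forward process, so the two prediction targets are related by

```latex
x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon
\qquad\Longrightarrow\qquad
x_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon}{\sqrt{\bar\alpha_t}}
```

When the network predicts x0 directly (as the code comment notes the paper does for the image embedding), the [-1, 1] clamp is skipped, presumably because the denoised quantity is an embedding rather than a pixel tensor bounded to that range.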
@@ -619,7 +639,7 @@ class DiffusionPrior(nn.Module):
|
||||
**text_cond
|
||||
)
|
||||
|
||||
to_predict = noise if not self.predict_x0 else image_embed
|
||||
to_predict = noise if not self.predict_x_start else image_embed
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(to_predict, x_recon)
|
||||
@@ -720,7 +740,7 @@ class ConvNextBlock(nn.Module):
|
||||
|
||||
inner_dim = int(dim_out * mult)
|
||||
self.net = nn.Sequential(
|
||||
ChanRMSNorm(dim) if norm else nn.Identity(),
|
||||
ChanLayerNorm(dim) if norm else nn.Identity(),
|
||||
nn.Conv2d(dim, inner_dim, 3, padding = 1),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
|
||||
@@ -756,8 +776,8 @@ class CrossAttention(nn.Module):
|
||||
|
||||
context_dim = default(context_dim, dim)
|
||||
|
||||
self.norm = RMSNorm(dim)
|
||||
self.norm_context = RMSNorm(context_dim)
|
||||
self.norm = LayerNorm(dim)
|
||||
self.norm_context = LayerNorm(context_dim)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
|
||||
@@ -825,10 +845,9 @@ class Unet(nn.Module):
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
attn_dim_head = 32,
|
||||
attn_heads = 8,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_cond_upsample_mode = 'bilinear',
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
sparse_attn = False,
|
||||
sparse_attn_window = 8, # window size for sparse attention
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
@@ -845,9 +864,6 @@ class Unet(nn.Module):
|
||||
# for eventual cascading diffusion
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
||||
self.lowres_blur_kernel_size = blur_kernel_size
|
||||
self.lowres_blur_sigma = blur_sigma
|
||||
|
||||
# determine dimensions
|
||||
|
||||
@@ -888,6 +904,10 @@ class Unet(nn.Module):
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
|
||||
|
||||
# attention related params
|
||||
|
||||
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
@@ -901,7 +921,7 @@ class Unet(nn.Module):
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
|
||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
@@ -909,7 +929,7 @@ class Unet(nn.Module):
|
||||
mid_dim = dims[-1]
|
||||
|
||||
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
|
||||
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
|
||||
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
@@ -918,7 +938,7 @@ class Unet(nn.Module):
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
|
||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
@@ -931,11 +951,16 @@ class Unet(nn.Module):
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def force_lowres_cond(self, lowres_cond):
|
||||
if lowres_cond == self.lowres_cond:
|
||||
def cast_model_parameters(
|
||||
self,
|
||||
*,
|
||||
lowres_cond,
|
||||
channels
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
|
||||
return self.__class__(**updated_kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -971,13 +996,6 @@ class Unet(nn.Module):
|
||||
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
|
||||
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
|
||||
# time conditioning
|
||||
@@ -1054,17 +1072,62 @@ class Unet(nn.Module):
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cond_upsample_mode = 'bilinear',
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_upsample_mode = cond_upsample_mode
|
||||
self.downsample_first = downsample_first
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_kernel_size = blur_kernel_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
cond_fmap,
|
||||
*,
|
||||
target_image_size,
|
||||
downsample_image_size = None,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
target_image_size = cast_tuple(target_image_size, 2)
|
||||
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
blur_sigma = default(blur_sigma, self.blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
|
||||
return cond_fmap
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
unet,
|
||||
*,
|
||||
clip,
|
||||
vae = None,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
beta_schedule = 'cosine',
|
||||
image_sizes = None # for cascading ddpm, image size at each stage
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(clip, CLIP)
|
||||
@@ -1076,11 +1139,28 @@ class Decoder(nn.Module):
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
for ind, one_unet in enumerate(cast_tuple(unet)):
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||
assert isinstance(one_unet, Unet)
|
||||
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||
|
||||
is_first = ind == 0
|
||||
one_unet = one_unet.force_lowres_cond(not is_first)
|
||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||
|
||||
unet_channels = default(latent_dim, self.channels)
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
channels = unet_channels
|
||||
)
|
||||
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# unet image sizes
|
||||
|
||||
@@ -1089,12 +1169,30 @@ class Decoder(nn.Module):
|
||||
|
||||
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
|
||||
self.image_sizes = image_sizes
|
||||
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||
|
||||
# predict x0 config
|
||||
|
||||
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
|
||||
# cascading ddpm related stuff
|
||||
|
||||
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
|
||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
|
||||
self.to_lowres_cond = LowresConditioner(
|
||||
cond_upsample_mode = lowres_cond_upsample_mode,
|
||||
downsample_first = lowres_downsample_first,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
)
|
||||
|
||||
# classifier free guidance
|
||||
|
||||
self.cond_drop_prob = cond_drop_prob
|
||||
|
||||
# noise schedule
|
||||
|
||||
if beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
elif beta_schedule == "linear":
|
||||
@@ -1109,7 +1207,7 @@ class Decoder(nn.Module):
|
||||
raise NotImplementedError()
|
||||
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod = torch.cumprod(alphas, axis = 0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
|
||||
|
||||
timesteps, = betas.shape
|
||||
@@ -1142,24 +1240,31 @@ class Decoder(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
@contextmanager
|
||||
def one_unet_in_gpu(self, unet_number):
|
||||
def get_unet(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
return self.unets[index]
|
||||
|
||||
@contextmanager
|
||||
def one_unet_in_gpu(self, unet_number = None, unet = None):
|
||||
assert exists(unet_number) ^ exists(unet)
|
||||
|
||||
if exists(unet_number):
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
self.cuda()
|
||||
self.unets.cpu()
|
||||
|
||||
unet = self.unets[index]
|
||||
unet.cuda()
|
||||
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
self.unets.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_encodings(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
return text_encodings[:, 1:]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image = resize_image_to(image, self.clip_image_size)
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
@@ -1188,34 +1293,48 @@ class Decoder(nn.Module):
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
|
||||
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
|
||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
if clip_denoised:
|
||||
if predict_x_start:
|
||||
x_recon = pred
|
||||
else:
|
||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||
|
||||
if clip_denoised and not predict_x_start:
|
||||
x_recon.clamp_(-1., 1.)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
|
||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device = device)
|
||||
|
||||
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
|
||||
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||
img = self.p_sample(
|
||||
unet,
|
||||
img,
|
||||
torch.full((b,), i, device = device, dtype = torch.long),
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img,
|
||||
predict_x_start = predict_x_start
|
||||
)
|
||||
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
@@ -1226,7 +1345,7 @@ class Decoder(nn.Module):
|
||||
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
|
||||
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
|
||||
@@ -1240,12 +1359,14 @@ class Decoder(nn.Module):
|
||||
cond_drop_prob = self.cond_drop_prob
|
||||
)
|
||||
|
||||
target = noise if not predict_x_start else x_start
|
||||
|
||||
if self.loss_type == 'l1':
|
||||
loss = F.l1_loss(noise, x_recon)
|
||||
loss = F.l1_loss(target, x_recon)
|
||||
elif self.loss_type == 'l2':
|
||||
loss = F.mse_loss(noise, x_recon)
|
||||
loss = F.mse_loss(target, x_recon)
|
||||
elif self.loss_type == "huber":
|
||||
loss = F.smooth_l1_loss(noise, x_recon)
|
||||
loss = F.smooth_l1_loss(target, x_recon)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -1255,27 +1376,56 @@ class Decoder(nn.Module):
|
||||
@eval_decorator
|
||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
||||
batch_size = image_embed.shape[0]
|
||||
channels = self.channels
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
img = None
|
||||
|
||||
for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
|
||||
with self.one_unet_in_gpu(ind + 1):
|
||||
shape = (batch_size, channels, image_size, image_size)
|
||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||
with self.one_unet_in_gpu(unet = unet):
|
||||
lowres_cond_img = None
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
|
||||
if unet.lowres_cond:
|
||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
shape,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
predict_x_start = predict_x_start,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
img = vae.decode(img)
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||
def forward(
|
||||
self,
|
||||
image,
|
||||
text = None,
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
unet_number = None
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
assert 1 <= unet_number <= len(self.unets)
|
||||
unet_index = unet_number - 1
|
||||
|
||||
index = unet_number - 1
|
||||
unet = self.unets[index]
|
||||
target_image_size = self.image_sizes[index]
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
vae = self.vaes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
@@ -1289,9 +1439,17 @@ class Decoder(nn.Module):
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
lowres_cond_img = image if index > 0 else None
|
||||
ddpm_image = resize_image_to(image, target_image_size)
|
||||
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||
|
||||
# main class
|
||||
|
||||
@@ -1318,6 +1476,7 @@ class DALLE2(nn.Module):
|
||||
cond_scale = 1.
|
||||
):
|
||||
device = next(self.parameters()).device
|
||||
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
|
||||
|
||||
if isinstance(text, str) or is_list_str(text):
|
||||
text = [text] if not isinstance(text, (list, tuple)) else text
|
||||
@@ -1325,4 +1484,8 @@ class DALLE2(nn.Module):
|
||||
|
||||
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
|
||||
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
|
||||
|
||||
if one_text:
|
||||
return images[0]
|
||||
|
||||
return images
|
||||
|
||||
Deleted file (the LatentDiffusion placeholder module, superseded by the vq-reg latent diffusion built into Decoder + VQGanVAE)

@@ -1,12 +0,0 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange

class LatentDiffusion(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x
dalle2_pytorch/vqgan_vae.py (new file, 563 lines)
@@ -0,0 +1,563 @@
|
||||
import copy
|
||||
import math
|
||||
from math import sqrt
|
||||
from functools import partial, wraps
|
||||
|
||||
from vector_quantize_pytorch import VectorQuantize as VQ
|
||||
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import grad as torch_grad
|
||||
import torchvision
|
||||
|
||||
from einops import rearrange, reduce, repeat
|
||||
|
||||
# constants
|
||||
|
||||
MList = nn.ModuleList
|
||||
|
||||
# helper functions
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
# decorators
|
||||
|
||||
def eval_decorator(fn):
|
||||
def inner(model, *args, **kwargs):
|
||||
was_training = model.training
|
||||
model.eval()
|
||||
out = fn(model, *args, **kwargs)
|
||||
model.train(was_training)
|
||||
return out
|
||||
return inner
|
||||
|
||||
def remove_vgg(fn):
|
||||
@wraps(fn)
|
||||
def inner(self, *args, **kwargs):
|
||||
has_vgg = hasattr(self, 'vgg')
|
||||
if has_vgg:
|
||||
vgg = self.vgg
|
||||
delattr(self, 'vgg')
|
||||
|
||||
out = fn(self, *args, **kwargs)
|
||||
|
||||
if has_vgg:
|
||||
self.vgg = vgg
|
||||
|
||||
return out
|
||||
return inner
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(),dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
# tensor helper functions
|
||||
|
||||
def log(t, eps = 1e-10):
|
||||
return torch.log(t + eps)
|
||||
|
||||
def gradient_penalty(images, output, weight = 10):
|
||||
batch_size = images.shape[0]
|
||||
gradients = torch_grad(outputs = output, inputs = images,
|
||||
grad_outputs = torch.ones(output.size(), device = images.device),
|
||||
create_graph = True, retain_graph = True, only_inputs = True)[0]
|
||||
|
||||
gradients = rearrange(gradients, 'b ... -> b (...)')
|
||||
return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
|
||||
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
def leaky_relu(p = 0.1):
|
||||
return nn.LeakyReLU(p)
|
||||
|
||||
def stable_softmax(t, dim = -1, alpha = 32 ** 2):
|
||||
t = t / alpha
|
||||
t = t - torch.amax(t, dim = dim, keepdim = True).detach()
|
||||
return (t * alpha).softmax(dim = dim)
|
||||
|
||||
def safe_div(numer, denom, eps = 1e-8):
|
||||
return numer / (denom + eps)
|
||||
|
||||
# gan losses
|
||||
|
||||
def hinge_discr_loss(fake, real):
|
||||
return (F.relu(1 + fake) + F.relu(1 - real)).mean()
|
||||
|
||||
def hinge_gen_loss(fake):
|
||||
return -fake.mean()
|
||||
|
||||
def bce_discr_loss(fake, real):
|
||||
return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
|
||||
|
||||
def bce_gen_loss(fake):
|
||||
return -log(torch.sigmoid(fake)).mean()
|
||||
|
||||
def grad_layer_wrt_loss(loss, layer):
|
||||
return torch_grad(
|
||||
outputs = loss,
|
||||
inputs = layer,
|
||||
grad_outputs = torch.ones_like(loss),
|
||||
retain_graph = True
|
||||
)[0].detach()
|
||||
|
||||
# vqgan vae
|
||||
|
||||
class LayerNormChan(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
eps = 1e-5
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
|
||||
mean = torch.mean(x, dim = 1, keepdim = True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.gamma
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dims,
|
||||
channels = 3,
|
||||
groups = 16,
|
||||
init_kernel_size = 5
|
||||
):
|
||||
super().__init__()
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
self.layers = MList([nn.Sequential(nn.Conv2d(channels, dims[0], init_kernel_size, padding = init_kernel_size // 2), leaky_relu())])
|
||||
|
||||
for dim_in, dim_out in dim_pairs:
|
||||
self.layers.append(nn.Sequential(
|
||||
nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1),
|
||||
nn.GroupNorm(groups, dim_out),
|
||||
leaky_relu()
|
||||
))
|
||||
|
||||
dim = dims[-1]
|
||||
self.to_logits = nn.Sequential( # return 5 x 5, for PatchGAN-esque training
|
||||
nn.Conv2d(dim, dim, 1),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(dim, 1, 4)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
for net in self.layers:
|
||||
x = net(x)
|
||||
|
||||
return self.to_logits(x)
|
||||
|
||||
class ContinuousPositionBias(nn.Module):
|
||||
""" from https://arxiv.org/abs/2111.09883 """
|
||||
|
||||
def __init__(self, *, dim, heads, layers = 2):
|
||||
super().__init__()
|
||||
self.net = MList([])
|
||||
self.net.append(nn.Sequential(nn.Linear(2, dim), leaky_relu()))
|
||||
|
||||
for _ in range(layers - 1):
|
||||
self.net.append(nn.Sequential(nn.Linear(dim, dim), leaky_relu()))
|
||||
|
||||
self.net.append(nn.Linear(dim, heads))
|
||||
self.register_buffer('rel_pos', None, persistent = False)
|
||||
|
||||
def forward(self, x):
|
||||
n, device = x.shape[-1], x.device
|
||||
fmap_size = int(sqrt(n))
|
||||
|
||||
if not exists(self.rel_pos):
|
||||
pos = torch.arange(fmap_size, device = device)
|
||||
grid = torch.stack(torch.meshgrid(pos, pos, indexing = 'ij'))
|
||||
grid = rearrange(grid, 'c i j -> (i j) c')
|
||||
rel_pos = rearrange(grid, 'i c -> i 1 c') - rearrange(grid, 'j c -> 1 j c')
|
||||
rel_pos = torch.sign(rel_pos) * torch.log(rel_pos.abs() + 1)
|
||||
self.register_buffer('rel_pos', rel_pos, persistent = False)
|
||||
|
||||
rel_pos = self.rel_pos.float()
|
||||
|
||||
for layer in self.net:
|
||||
rel_pos = layer(rel_pos)
|
||||
|
||||
bias = rearrange(rel_pos, 'i j h -> h i j')
|
||||
return x + bias
|
||||
|
||||
class GLUResBlock(nn.Module):
|
||||
def __init__(self, chan, groups = 16):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(chan, chan * 2, 3, padding = 1),
|
||||
nn.GLU(dim = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
nn.Conv2d(chan, chan * 2, 3, padding = 1),
|
||||
nn.GLU(dim = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
nn.Conv2d(chan, chan, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, chan, groups = 16):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(chan, chan, 3, padding = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(chan, chan, 3, padding = 1),
|
||||
nn.GroupNorm(groups, chan),
|
||||
leaky_relu(),
|
||||
nn.Conv2d(chan, chan, 1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x) + x
|
||||
|
||||
class VQGanAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
dim_head = 64,
|
||||
heads = 8,
|
||||
dropout = 0.
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.scale = dim_head ** -0.5
|
||||
inner_dim = heads * dim_head
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.pre_norm = LayerNormChan(dim)
|
||||
|
||||
self.cpb = ContinuousPositionBias(dim = dim // 4, heads = heads)
|
||||
self.to_qkv = nn.Conv2d(dim, inner_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
|
||||
|
||||
def forward(self, x):
|
||||
h = self.heads
|
||||
height, width, residual = *x.shape[-2:], x.clone()
|
||||
|
||||
x = self.pre_norm(x)
|
||||
|
||||
q, k, v = self.to_qkv(x).chunk(3, dim = 1)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = h), (q, k, v))
|
||||
|
||||
sim = einsum('b h c i, b h c j -> b h i j', q, k) * self.scale
|
||||
|
||||
sim = self.cpb(sim)
|
||||
|
||||
attn = stable_softmax(sim, dim = -1)
|
||||
attn = self.dropout(attn)
|
||||
|
||||
out = einsum('b h i j, b h c j -> b h c i', attn, v)
|
||||
out = rearrange(out, 'b h c (x y) -> b (h c) x y', x = height, y = width)
|
||||
out = self.to_out(out)
|
||||
|
||||
return out + residual
|
||||
|
||||
class NullVQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channels
|
||||
):
|
||||
super().__init__()
|
||||
self.encoded_dim = channels
|
||||
self.layers = 0
|
||||
|
||||
def get_encoded_fmap_size(self, size):
|
||||
return size
|
||||
|
||||
def copy_for_eval(self):
|
||||
return self
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
class VQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
image_size,
|
||||
channels = 3,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
num_resnet_blocks = 1,
|
||||
vgg = None,
|
||||
vq_codebook_size = 512,
|
||||
vq_decay = 0.8,
|
||||
vq_commitment_weight = 1.,
|
||||
vq_kmeans_init = True,
|
||||
vq_use_cosine_sim = True,
|
||||
use_attn = True,
|
||||
attn_dim_head = 64,
|
||||
attn_heads = 8,
|
||||
resnet_groups = 16,
|
||||
attn_dropout = 0.,
|
||||
first_conv_kernel_size = 5,
|
||||
use_vgg_and_gan = True,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
assert dim % resnet_groups == 0, f'dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)'
|
||||
|
||||
vq_kwargs, kwargs = groupby_prefix_and_trim('vq_', kwargs)
|
||||
|
||||
self.image_size = image_size
|
||||
self.channels = channels
|
||||
self.layers = layers
|
||||
self.fmap_size = image_size // (2 ** layers)
|
||||
self.codebook_size = vq_codebook_size
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
codebook_dim = layer_dims[-1]
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
|
||||
if layer_use_attn:
|
||||
prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
for _ in range(layer_num_resnet_blocks):
|
||||
append(self.encoders, ResBlock(dim_out, groups = resnet_groups))
|
||||
prepend(self.decoders, GLUResBlock(dim_out, groups = resnet_groups))
|
||||
|
||||
if layer_use_attn:
|
||||
append(self.encoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
|
||||
|
||||
prepend(self.encoders, nn.Conv2d(channels, dim, first_conv_kernel_size, padding = first_conv_kernel_size // 2))
|
||||
append(self.decoders, nn.Conv2d(dim, channels, 1))
|
||||
|
||||
self.vq = VQ(
|
||||
dim = codebook_dim,
|
||||
codebook_size = vq_codebook_size,
|
||||
decay = vq_decay,
|
||||
commitment_weight = vq_commitment_weight,
|
||||
accept_image_fmap = True,
|
||||
kmeans_init = vq_kmeans_init,
|
||||
use_cosine_sim = vq_use_cosine_sim,
|
||||
**vq_kwargs
|
||||
)
|
||||
|
||||
# reconstruction loss
|
||||
|
||||
self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
|
||||
|
||||
# turn off GAN and perceptual loss if grayscale
|
||||
|
||||
self.vgg = None
|
||||
self.discr = None
|
||||
self.use_vgg_and_gan = use_vgg_and_gan
|
||||
|
||||
if not use_vgg_and_gan:
|
||||
return
|
||||
|
||||
# perceptual loss
|
||||
|
||||
if exists(vgg):
|
||||
self.vgg = vgg
|
||||
else:
|
||||
self.vgg = torchvision.models.vgg16(pretrained = True)
|
||||
self.vgg.classifier = nn.Sequential(*self.vgg.classifier[:-2])
|
||||
|
||||
# gan related losses
|
||||
|
||||
self.discr = Discriminator(dims = dims, channels = channels)
|
||||
|
||||
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
def copy_for_eval(self):
|
||||
device = next(self.parameters()).device
|
||||
vae_copy = copy.deepcopy(self.cpu())
|
||||
|
||||
if vae_copy.use_vgg_and_gan:
|
||||
del vae_copy.discr
|
||||
del vae_copy.vgg
|
||||
|
||||
vae_copy.eval()
|
||||
return vae_copy.to(device)
|
||||
|
||||
@remove_vgg
|
||||
def state_dict(self, *args, **kwargs):
|
||||
return super().state_dict(*args, **kwargs)
|
||||
|
||||
@remove_vgg
|
||||
def load_state_dict(self, *args, **kwargs):
|
||||
return super().load_state_dict(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def codebook(self):
|
||||
return self.vq.codebook
|
||||
|
||||
def encode(self, fmap):
|
||||
for enc in self.encoders:
|
||||
fmap = enc(fmap)
|
||||
|
||||
return fmap
|
||||
|
||||
def decode(self, fmap, return_indices_and_loss = False):
|
||||
fmap, indices, commit_loss = self.vq(fmap)
|
||||
|
||||
for dec in self.decoders:
|
||||
fmap = dec(fmap)
|
||||
|
||||
if not return_indices_and_loss:
|
||||
return fmap
|
||||
|
||||
return fmap, indices, commit_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
img,
|
||||
return_loss = False,
|
||||
return_discr_loss = False,
|
||||
return_recons = False,
|
||||
add_gradient_penalty = True
|
||||
):
|
||||
batch, channels, height, width, device = *img.shape, img.device
|
||||
assert height == self.image_size and width == self.image_size, f'height and width of input image must be equal to {self.image_size}'
|
||||
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
|
||||
|
||||
fmap = self.encode(img)
|
||||
|
||||
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
|
||||
|
||||
if not return_loss and not return_discr_loss:
|
||||
return fmap
|
||||
|
||||
assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
|
||||
|
||||
# whether to return discriminator loss
|
||||
|
||||
if return_discr_loss:
|
||||
assert exists(self.discr), 'discriminator must exist to train it'
|
||||
|
||||
fmap.detach_()
|
||||
img.requires_grad_()
|
||||
|
||||
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
|
||||
|
||||
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
|
||||
|
||||
if add_gradient_penalty:
|
||||
gp = gradient_penalty(img, img_discr_logits)
|
||||
loss = discr_loss + gp
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
return loss
|
||||
|
||||
# reconstruction loss
|
||||
|
||||
recon_loss = self.recon_loss_fn(fmap, img)
|
||||
|
||||
# early return if training on grayscale
|
||||
|
||||
if not self.use_vgg_and_gan:
|
||||
if return_recons:
|
||||
return recon_loss, fmap
|
||||
|
||||
return recon_loss
|
||||
|
||||
# perceptual loss
|
||||
|
||||
img_vgg_input = img
|
||||
fmap_vgg_input = fmap
|
||||
|
||||
if img.shape[1] == 1:
|
||||
# handle grayscale for vgg
|
||||
img_vgg_input, fmap_vgg_input = map(lambda t: repeat(t, 'b 1 ... -> b c ...', c = 3), (img_vgg_input, fmap_vgg_input))
|
||||
|
||||
img_vgg_feats = self.vgg(img_vgg_input)
|
||||
recon_vgg_feats = self.vgg(fmap_vgg_input)
|
||||
perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
|
||||
|
||||
# generator loss
|
||||
|
||||
gen_loss = self.gen_loss(self.discr(fmap))
|
||||
|
||||
# calculate adaptive weight
|
||||
|
||||
last_dec_layer = self.decoders[-1].weight
|
||||
|
||||
norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p = 2)
|
||||
norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p = 2)
|
||||
|
||||
adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
|
||||
adaptive_weight.clamp_(max = 1e4)
|
||||
|
||||
# combine losses
|
||||
|
||||
loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
|
||||
|
||||
if return_recons:
|
||||
return loss, fmap
|
||||
|
||||
return loss
|
||||
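
The adaptive weight at the end of `VQGanVAE.forward` follows the adaptive-weight heuristic popularized by VQGAN (Esser et al.): the generator (GAN) term is rescaled by the ratio of gradient norms taken at the last decoder layer, so the adversarial signal never swamps the perceptual one. In the notation of the code above:

```latex
\lambda = \min\!\left(\frac{\lVert \nabla_{G_L}\,\mathcal{L}_{\text{perceptual}} \rVert_2}{\lVert \nabla_{G_L}\,\mathcal{L}_{\text{GAN}} \rVert_2 + \varepsilon},\; 10^{4}\right),
\qquad
\mathcal{L} = \mathcal{L}_{\text{recon}} + \mathcal{L}_{\text{perceptual}} + \mathcal{L}_{\text{commit}} + \lambda\,\mathcal{L}_{\text{GAN}}
```

where G_L is the weight of the final decoder convolution and epsilon is the safe_div constant (1e-8).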
setup.py (3)

@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
  version = '0.0.28',
  version = '0.0.43',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',
@@ -30,6 +30,7 @@ setup(
    'torch>=1.10',
    'torchvision',
    'tqdm',
    'vector-quantize-pytorch',
    'x-clip>=0.4.4',
    'youtokentome'
  ],