Compare commits

...

26 Commits

Author SHA1 Message Date
Phil Wang
27a33e1b20 complete contextmanager method for keeping only one unet in GPU during training or inference 2022-04-20 10:46:13 -07:00
Phil Wang
6f941a219a give time tokens a surface area of 2 tokens as default, make it so researcher can customize which unet actually is conditioned on image embeddings and/or text encodings 2022-04-20 10:04:47 -07:00
Phil Wang
ddde8ca1bf fix cosine beta schedule, thanks to @Zhengxinyang 2022-04-19 20:54:28 -07:00
Phil Wang
c26b77ad20 todo 2022-04-19 13:07:32 -07:00
Phil Wang
c5b4aab8e5 intent 2022-04-19 11:00:05 -07:00
Phil Wang
a35c309b5f add sparse attention layers in between convnext blocks in unet (grid like attention, used in mobilevit, maxvit [bytedance ai], as well as a growing number of attention-based GANs) 2022-04-19 09:49:03 -07:00
Phil Wang
55bdcb98b9 scaffold for latent diffusion 2022-04-19 09:26:58 -07:00
Phil Wang
82328f16cd same for text encodings for decoder ddpm training 2022-04-18 14:41:02 -07:00
Phil Wang
6fee4fce6e also allow for image embedding to be passed into the diffusion model, in the case one wants to generate image embedding once and then train multiple unets in one iteration 2022-04-18 14:00:38 -07:00
Phil Wang
a54e309269 prioritize todos, play project management 2022-04-18 13:28:01 -07:00
Phil Wang
c6bfd7fdc8 readme 2022-04-18 12:43:10 -07:00
Phil Wang
960a79857b use some magic just this once to remove the need for researchers to think 2022-04-18 12:40:43 -07:00
Phil Wang
7214df472d todo 2022-04-18 12:18:19 -07:00
Phil Wang
00ae50999b make kernel size and sigma for gaussian blur for cascading DDPM overridable at forward. also make sure unets are wrapped in a modulelist so that at sample time, blurring does not happen 2022-04-18 12:04:31 -07:00
Phil Wang
6cddefad26 readme 2022-04-18 11:52:25 -07:00
Phil Wang
0332eaa6ff complete first pass at full cascading DDPM setup in Decoder, flexible enough to support one unet for testing 2022-04-18 11:44:56 -07:00
Phil Wang
1cce4225eb 0.0.18 2022-04-17 07:29:34 -07:00
Phil Wang
5ab0700bab Merge pull request #14 from kashif/loss-schedule
added huber loss and other schedulers
2022-04-17 07:29:10 -07:00
Kashif Rasul
b0f2fbaa95 schedule to Prior 2022-04-17 15:21:47 +02:00
Kashif Rasul
51361c2d15 added beta_schedule argument 2022-04-17 15:19:33 +02:00
Kashif Rasul
42d6e47387 added huber loss and other schedulers 2022-04-17 15:14:05 +02:00
Phil Wang
1e939153fb link to AssemblyAI explanation 2022-04-15 12:58:57 -07:00
Phil Wang
1abeb8918e personal project management for next week 2022-04-15 08:04:01 -07:00
Phil Wang
b423855483 commit to jax version 2022-04-15 07:16:25 -07:00
Phil Wang
c400d8758c prepare for cascading diffusion in unet, save the full progressive upsampling architecture to be built next week 2022-04-15 07:03:28 -07:00
Phil Wang
bece206699 fix bug thanks to @jihoonerd 2022-04-15 06:44:40 -07:00
5 changed files with 436 additions and 117 deletions

129
README.md

@@ -2,7 +2,9 @@
## DALL-E 2 - Pytorch (wip)
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch. <a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a>
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
<a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a> | <a href="https://www.youtube.com/watch?v=F1X4fHzF4mQ">AssemblyAI explainer</a>
The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)
@@ -12,9 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
Do let me know if anyone is interested in a Jax version https://github.com/lucidrains/DALLE2-pytorch/discussions/8
For all of you emailing me (there are a lot of you), the best way to contribute is through pull requests. Everything is open sourced after all. All my thoughts are public. This is your moment to participate.
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Install
@@ -109,7 +109,7 @@ unet = Unet(
# decoder, which contains the unet and clip
decoder = Decoder(
net = unet,
unet = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
@@ -182,7 +182,82 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
This can easily be used within this framework like so
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet1 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet1, unet2), # insert both unets in order of lowest to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 512, 512).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
loss = decoder(images, unet_number = 2)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
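The cascading flow above reduces to a simple chain: each unet's sample becomes the low-resolution conditioning image for the next stage. A minimal sketch of that logic, where `sample_one` is a hypothetical stand-in for the Decoder's per-unet reverse diffusion loop (not the repository API):

```python
def cascade_sample(unets, image_sizes, image_embed, sample_one):
    # sample_one(unet, size, image_embed, lowres) runs one full DDPM reverse
    # process for a single unet; hypothetical stand-in for Decoder.p_sample_loop
    img = None
    for unet, size in zip(unets, image_sizes):
        # each stage conditions on the previous stage's lower-resolution output
        img = sample_one(unet, size, image_embed, lowres=img)
    return img
```

The first stage receives `lowres=None` (it is unconditional on any prior image), which mirrors how the Decoder forces `lowres_cond=False` on the first unet only.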
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
```python
from dalle2_pytorch import DALLE2
@@ -261,7 +336,7 @@ loss.backward()
# decoder (with unet)
unet = Unet(
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
@@ -269,15 +344,25 @@ unet = Unet(
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
net = unet,
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -291,11 +376,13 @@ images = dalle2(
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image
# save your image (in this example, of size 256x256)
```
Everything in this readme should run without error
You can also train the decoder on images larger (say 512x512) than the size at which CLIP was trained (256x256). The images will be resized down to CLIP's image resolution before computing the image embeddings
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
@@ -320,12 +407,14 @@ Offer training wrappers
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [x] augment unet so that it can also be conditioned on text encodings (although in the paper they hinted this didn't make much of a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007
## Citations
@@ -374,4 +463,12 @@ Offer training wrappers
}
```
```bibtex
@inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer},
author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
year = {2022}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>


@@ -6,4 +6,4 @@ def main():
@click.command()
@click.argument('text')
def dream(text):
return image
return 'not ready yet'


@@ -1,6 +1,8 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
from contextlib import contextmanager
import torch
import torch.nn.functional as F
@@ -11,7 +13,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import filter2d
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -29,6 +31,9 @@ def default(val, d):
return val
return d() if isfunction(d) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
@@ -64,6 +69,15 @@ def freeze_model_and_make_eval_(model):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
shape = cast_tuple(image_size, 2)
orig_image_size = t.shape[-2:]
if orig_image_size == shape:
return t
return F.interpolate(t, size = shape, mode = mode)
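`resize_image_to` leans on the `cast_tuple` helper defined above to accept either a bare int or an `(h, w)` pair; its behavior is easy to check in isolation:

```python
def cast_tuple(val, length=1):
    # wrap a non-tuple value into a tuple of the requested length
    return val if isinstance(val, tuple) else ((val,) * length)

# an int image size becomes a square (h, w) pair; tuples pass through untouched
assert cast_tuple(256, 2) == (256, 256)
assert cast_tuple((256, 512), 2) == (256, 512)
```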
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
@@ -92,12 +106,35 @@ def cosine_beta_schedule(timesteps, s = 0.008):
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, steps, steps)
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
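For intuition, here is a dependency-free sketch of two of the schedules above, with plain Python standing in for torch. Any valid schedule should produce betas in (0, 1) so the cumulative product of alphas decays toward zero over the diffusion process:

```python
import math

def linear_betas(timesteps):
    # same endpoints as the torch version; the 1000/timesteps scale keeps
    # the schedule consistent across different timestep counts
    scale = 1000 / timesteps
    beta_start, beta_end = scale * 0.0001, scale * 0.02
    return [beta_start + (beta_end - beta_start) * i / (timesteps - 1)
            for i in range(timesteps)]

def cosine_betas(timesteps, s=0.008):
    # alphas_cumprod follows a squared cosine; each beta is one minus the
    # ratio of consecutive cumulative products, clipped away from 1
    f = [math.cos(((i / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
         for i in range(timesteps + 1)]
    f = [v / f[0] for v in f]
    betas = [1 - f[i + 1] / f[i] for i in range(timesteps)]
    return [min(b, 0.999) for b in betas]

# every schedule yields betas strictly inside (0, 1)
for betas in (linear_betas(1000), cosine_betas(1000)):
    assert all(0 < b < 1 for b in betas)
```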
# diffusion prior
class RMSNorm(nn.Module):
@@ -427,10 +464,11 @@ class DiffusionPrior(nn.Module):
net,
*,
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
predict_x0 = True
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
predict_x0=True,
beta_schedule="cosine",
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -446,11 +484,22 @@ class DiffusionPrior(nn.Module):
self.predict_x0 = predict_x0
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
betas = cosine_beta_schedule(timesteps)
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
@@ -550,31 +599,6 @@ class DiffusionPrior(nn.Module):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.no_grad()
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -601,11 +625,39 @@ class DiffusionPrior(nn.Module):
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
return loss
@torch.no_grad()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
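The best-of-`num_samples_per_batch` selection above reduces to a rerank by cosine similarity: sample several image embeddings, keep the one most similar to the text embedding. A plain-Python sketch of that step, torch-free and for illustration only:

```python
import math

def l2norm(v):
    # normalize a vector to unit length so dot product equals cosine similarity
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def rerank(text_embed, candidate_image_embeds):
    # score each sampled candidate against the text embedding, keep the best
    t = l2norm(text_embed)
    sims = [sum(a * b for a, b in zip(t, l2norm(c)))
            for c in candidate_image_embeds]
    return candidate_image_embeds[max(range(len(sims)), key=sims.__getitem__)]
```

With `num_samples_per_batch = 2`, this is exactly a pick-the-better-of-two per batch element.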
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
@@ -625,17 +677,6 @@ def Upsample(dim):
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(filt, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -758,6 +799,20 @@ class CrossAttention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class GridAttention(nn.Module):
def __init__(self, *args, window_size = 8, **kwargs):
super().__init__()
self.window_size = window_size
self.attn = Attention(*args, **kwargs)
def forward(self, x):
h, w = x.shape[-2:]
wsz = self.window_size
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
out = self.attn(x)
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
return out
class Unet(nn.Module):
def __init__(
self,
@@ -766,14 +821,41 @@ class Unet(nn.Module):
image_embed_dim,
cond_dim = None,
num_image_tokens = 4,
num_time_tokens = 2,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
blur_kernel_size = 3,
sparse_attn = False,
sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
cond_on_text_encodings = False,
cond_on_image_embeds = False,
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
self._locals = locals()
del self._locals['self']
del self._locals['__class__']
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
self.channels = channels
dims = [channels, *map(lambda m: dim * m, dim_mults)]
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time, image embeddings, and optional text encoding
@@ -784,8 +866,8 @@ class Unet(nn.Module):
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
nn.Linear(dim * 4, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
self.image_to_cond = nn.Sequential(
@@ -795,6 +877,12 @@ class Unet(nn.Module):
self.text_to_cond = nn.LazyLinear(cond_dim)
# finer control over whether to condition on image embeddings and text encodings
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
self.cond_on_text_encodings = cond_on_text_encodings
self.cond_on_image_embeds = cond_on_image_embeds
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
@@ -807,27 +895,32 @@ class Unet(nn.Module):
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
Upsample(dim_in) if not is_last else nn.Identity()
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
@@ -836,6 +929,15 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1)
)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def force_lowres_cond(self, lowres_cond):
if lowres_cond == self.lowres_cond:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
return self.__class__(**updated_kwargs)
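`force_lowres_cond` works by stashing the constructor arguments in `self._locals` and re-instantiating the class with one argument overridden. The pattern in miniature (a hypothetical class, not the repository's Unet):

```python
class Reconfigurable:
    def __init__(self, dim, lowres_cond=False):
        # stash constructor kwargs so the object can re-create itself later
        self._locals = dict(dim=dim, lowres_cond=lowres_cond)
        self.dim = dim
        self.lowres_cond = lowres_cond

    def force_lowres_cond(self, lowres_cond):
        # nothing to do if the setting already matches
        if lowres_cond == self.lowres_cond:
            return self
        # otherwise rebuild with the single argument overridden
        return self.__class__(**{**self._locals, 'lowres_cond': lowres_cond})
```

This is why the real Unet deletes `self` and `__class__` from `locals()` before saving: only genuine constructor kwargs may be replayed.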
def forward_with_cond_scale(
self,
*args,
@@ -856,29 +958,56 @@ class Unet(nn.Module):
time,
*,
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.
cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
batch_size, device = x.shape[0], x.device
# add low resolution conditioning, if present
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
time_tokens = self.time_mlp(time)
# conditional dropout
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_tokens = self.image_to_cond(image_embed)
image_tokens = None
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
if self.cond_on_image_embeds:
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
# take care of text encodings (optional)
if exists(text_encodings):
text_tokens = None
if exists(text_encodings) and self.cond_on_text_encodings:
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
@@ -888,30 +1017,38 @@ class Unet(nn.Module):
# main conditioning tokens (c)
c = torch.cat((time_tokens, image_tokens), dim = -2)
c = time_tokens
if exists(image_tokens):
c = torch.cat((c, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up
hiddens = []
for convnext, convnext2, downsample in self.downs:
for convnext, sparse_attn, convnext2, downsample in self.downs:
x = convnext(x, c)
x = sparse_attn(x)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, mid_c)
x = self.mid_attn(x)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_block2(x, mid_c)
for convnext, convnext2, upsample in self.ups:
for convnext, sparse_attn, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c)
x = sparse_attn(x)
x = convnext2(x, c)
x = upsample(x)
@@ -920,28 +1057,60 @@ class Unet(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
net,
unet,
*,
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1'
loss_type = 'l1',
beta_schedule = 'cosine',
image_sizes = None # for cascading ddpm, image size at each stage
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.net = net
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.image_size = clip.image_size
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
self.unets = nn.ModuleList([])
for ind, one_unet in enumerate(cast_tuple(unet)):
is_first = ind == 0
one_unet = one_unet.force_lowres_cond(not is_first)
self.unets.append(one_unet)
# unet image sizes
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.cond_drop_prob = cond_drop_prob
betas = cosine_beta_schedule(timesteps)
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
@@ -973,11 +1142,26 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@contextmanager
def one_unet_in_gpu(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
self.cuda()
self.unets.cpu()
unet = self.unets[index]
unet.cuda()
yield
self.unets.cpu()
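`one_unet_in_gpu` is an instance of a general keep-one-resident pattern: park every model, wake only the one being used, park everything again on exit. A device-agnostic sketch, with a `try/finally` added so cleanup runs even if the body raises (my addition, not the repository code):

```python
from contextlib import contextmanager

@contextmanager
def one_active(items, index, activate, deactivate):
    # park everything, wake only the requested item
    for item in items:
        deactivate(item)
    activate(items[index])
    try:
        yield items[index]
    finally:
        # restore the parked state even on exceptions
        for item in items:
            deactivate(item)
```

In the Decoder, `activate`/`deactivate` would be `.cuda()`/`.cpu()` on the unets, keeping only the unet under training or sampling resident in GPU memory.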
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -1004,8 +1188,9 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
if clip_denoised:
x_recon.clamp_(-1., 1.)
@@ -1014,33 +1199,25 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
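The `nonzero_mask` in `p_sample` implements the standard DDPM rule that the final step (t == 0) returns the posterior mean with no added noise, while every earlier step adds noise scaled by `exp(0.5 * log_variance)`. A scalar sketch of that rule:

```python
import math

def ddpm_step(model_mean, model_log_variance, t, noise):
    # t == 0 is the last denoising step: emit the mean directly, no noise
    nonzero_mask = 0.0 if t == 0 else 1.0
    return model_mean + nonzero_mask * math.exp(0.5 * model_log_variance) * noise

assert ddpm_step(1.0, 0.0, t=0, noise=5.0) == 1.0  # no noise at t == 0
assert ddpm_step(1.0, 0.0, t=3, noise=5.0) == 6.0  # exp(0) == 1, full noise
```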
@torch.no_grad()
def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
-img = torch.randn(shape, device=device)
-for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
-img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
+img = torch.randn(shape, device = device)
+for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
return img
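The loop above starts from pure Gaussian noise and repeatedly calls `p_sample`, stepping the timestep down to zero. A shape-level sketch of that pattern, where `denoise_step` is a hypothetical stand-in for the real `p_sample` model call:

```python
import torch

# reverse-diffusion loop sketch: start from noise, denoise step by step
def p_sample_loop_sketch(denoise_step, shape, num_timesteps):
    img = torch.randn(shape)  # pure Gaussian noise at t = T
    for i in reversed(range(num_timesteps)):
        # the same timestep index is broadcast to the whole batch
        t = torch.full((shape[0],), i, dtype = torch.long)
        img = denoise_step(img, t)
    return img

# stub denoiser just to show the call shape; a real one is a conditioned unet
out = p_sample_loop_sketch(lambda img, t: img * 0.9, (2, 3, 8, 8), num_timesteps = 5)
```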
@torch.no_grad()
-def sample(self, image_embed, text = None, cond_scale = 1.):
-batch_size = image_embed.shape[0]
-image_size = self.image_size
-channels = self.channels
-text_encodings = self.get_text_encodings(text) if exists(text) else None
-return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -1049,16 +1226,17 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
-def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
+def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
-x_recon = self.net(
+x_recon = unet(
x_noisy,
t,
image_embed = image_embed,
text_encodings = text_encodings,
+lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob
)
@@ -1066,22 +1244,54 @@ class Decoder(nn.Module):
loss = F.l1_loss(noise, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(noise, x_recon)
else:
raise NotImplementedError()
return loss
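The loss dispatch above compares the unet's prediction against the actual noise that was added. Isolated as a sketch (the function name `noise_loss` is illustrative; the library keeps this inline in `p_losses`):

```python
import torch
import torch.nn.functional as F

# the unet is trained to predict the injected noise, so pred and target
# are both noise-shaped tensors; loss_type selects the regression loss
def noise_loss(pred, target, loss_type = 'l1'):
    if loss_type == 'l1':
        return F.l1_loss(pred, target)
    elif loss_type == 'l2':
        return F.mse_loss(pred, target)
    elif loss_type == 'huber':
        return F.smooth_l1_loss(pred, target)
    raise NotImplementedError(loss_type)

pred, target = torch.zeros(4, 3, 8, 8), torch.ones(4, 3, 8, 8)
print(noise_loss(pred, target, 'l1').item())     # 1.0
print(noise_loss(pred, target, 'huber').item())  # 0.5 (quadratic region boundary)
```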
-def forward(self, image, text = None):
-b, device, img_size, = image.shape[0], image.device, self.image_size
-check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
+@torch.no_grad()
+@eval_decorator
+def sample(self, image_embed, text = None, cond_scale = 1.):
+batch_size = image_embed.shape[0]
+channels = self.channels
+text_encodings = self.get_text_encodings(text) if exists(text) else None
+img = None
+for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
+with self.one_unet_in_gpu(ind + 1):
+shape = (batch_size, channels, image_size, image_size)
+img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
+return img
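The new `sample` walks the cascade: each unet samples at its own resolution, and the previous stage's output is passed in as `lowres_cond_img` to condition the next. A shape-level sketch of that pattern, with hypothetical stub callables standing in for full `p_sample_loop` calls:

```python
import torch

# cascading DDPM sampling sketch: each stage conditions on the one before it
def cascade_sample_sketch(unets, image_sizes, batch_size, channels = 3):
    img = None
    for unet, image_size in zip(unets, image_sizes):
        noise = torch.randn(batch_size, channels, image_size, image_size)
        # the first stage gets lowres_cond_img = None; later stages get
        # the previous stage's output to upsample and refine
        img = unet(noise, lowres_cond_img = img)
    return img

stub_unet = lambda x, lowres_cond_img = None: x  # stand-in for a sampling loop
out = cascade_sample_sketch([stub_unet, stub_unet], [64, 256], batch_size = 2)
```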
+def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
+assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
+unet_number = default(unet_number, 1)
+assert 1 <= unet_number <= len(self.unets)
+index = unet_number - 1
+unet = self.unets[index]
+target_image_size = self.image_sizes[index]
+b, c, h, w, device, = *image.shape, image.device
+check_shape(image, 'b c h w', c = self.channels)
+assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-image_embed = self.get_image_embed(image)
-text_encodings = self.get_text_encodings(text) if exists(text) else None
+if not exists(image_embed):
+image_embed = self.get_image_embed(image)
-loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
-return loss
+text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
+lowres_cond_img = image if index > 0 else None
+ddpm_image = resize_image_to(image, target_image_size)
+return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
# main class


@@ -0,0 +1,12 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange
+class LatentDiffusion(nn.Module):
+def __init__(self):
+super().__init__()
+def forward(self, x):
+return x


@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
-version = '0.0.15',
+version = '0.0.28',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',