Compare commits


8 Commits

3 changed files with 254 additions and 83 deletions

README.md

@@ -109,7 +109,7 @@ unet = Unet(
# decoder, which contains the unet and clip
decoder = Decoder(
net = unet,
unet = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
@@ -182,7 +182,82 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
Finally, to generate the DALL-E2 images from text, insert the trained `DiffusionPrior` as well as the `Decoder` (which both contain `CLIP`, a unet, and a causal transformer)
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
This can easily be used within this framework like so:
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet1 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
    unet = (unet1, unet2), # insert both unets in order of lowest resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of these)
images = torch.randn(4, 3, 512, 512).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
loss = decoder(images, unet_number = 2)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
Finally, to generate the DALL-E2 images from text, insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and the unet(s))
```python
from dalle2_pytorch import DALLE2
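# a minimal, hedged sketch of the full usage: wire the trained prior and decoder into the
# DALLE2 wrapper, then generate images from text. the names `diffusion_prior` and `decoder`
# are assumed to be the instances trained in the snippets above, and the prompt is a placeholder

dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
)

images = dalle2(
    ['your text prompt here'], # placeholder prompt
    cond_scale = 2.            # classifier free guidance strength (> 1 strengthens the conditioning)
)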
@@ -261,7 +336,7 @@ loss.backward()
# decoder (with unet)
unet = Unet(
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
@@ -269,15 +344,25 @@ unet = Unet(
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
net = unet,
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -291,11 +376,13 @@ images = dalle2(
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image
# save your image (in this example, of size 256x256)
```
Everything in this readme should run without error
You can also train the decoder on images larger than the size at which CLIP was trained (say 512x512 images with CLIP trained at 256x256). The images will be resized down to the CLIP image resolution when computing the image embeddings
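As a rough sketch of what that resizing amounts to (assuming the bilinear `resize_image_to` helper added in this change), the larger training images are simply downsampled to CLIP's resolution before the image embedding is computed, while the decoder itself still denoises at the full resolution:
```python
import torch
import torch.nn.functional as F

# hypothetical batch of 512x512 training images, with CLIP trained at 256x256
images = torch.randn(4, 3, 512, 512)

# downsampled copy used only for computing the CLIP image embedding
images_for_clip = F.interpolate(images, size = (256, 256), mode = 'bilinear')

print(images_for_clip.shape) # torch.Size([4, 3, 256, 256])
```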
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
@@ -321,11 +408,11 @@ Offer training wrappers
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [x] augment unet so that it can also be conditioned on text encodings (although in the paper they hinted this didn't make much of a difference)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in the paper)
- [ ] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution's unet separately)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
- [ ] train on a toy task, offer in colab
## Citations

dalle2_pytorch/dalle2_pytorch.py

@@ -1,6 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
import torch
import torch.nn.functional as F
@@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters.gaussian import GaussianBlur2d
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -29,6 +30,9 @@ def default(val, d):
return val
return d() if isfunction(d) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
@@ -64,6 +68,15 @@ def freeze_model_and_make_eval_(model):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
shape = cast_tuple(image_size, 2)
orig_image_size = t.shape[-2:]
if orig_image_size == shape:
return t
return F.interpolate(t, size = shape, mode = mode)
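# e.g. resize_image_to(images, 256) bilinearly resizes a (b, c, h, w) batch to 256 x 256 (and is a no-op if the images are already that size)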
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
@@ -585,31 +598,6 @@ class DiffusionPrior(nn.Module):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.no_grad()
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -643,6 +631,32 @@ class DiffusionPrior(nn.Module):
return loss
@torch.no_grad()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
@@ -797,15 +811,23 @@ class Unet(nn.Module):
channels = 3,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1
blur_sigma = 0.1,
blur_kernel_size = 3,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
self._locals = locals()
del self._locals['self']
del self._locals['__class__']
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
@@ -847,27 +869,30 @@ class Unet(nn.Module):
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
Upsample(dim_in) if not is_last else nn.Identity()
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
@@ -876,6 +901,15 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1)
)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def force_lowres_cond(self, lowres_cond):
if lowres_cond == self.lowres_cond:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
self,
*args,
@@ -898,20 +932,24 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.
cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
batch_size, device = x.shape[0], x.device
# add low resolution conditioning, if present
assert not self.lowres_cond and not exists(lowres_cond_img), 'low resolution conditioning image must be present'
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode)
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
@@ -964,7 +1002,10 @@ class Unet(nn.Module):
x = downsample(x)
x = self.mid_block1(x, mid_c)
x = self.mid_attn(x)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_block2(x, mid_c)
for convnext, convnext2, upsample in self.ups:
@@ -978,22 +1019,42 @@ class Unet(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
net,
unet,
*,
clip,
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
beta_schedule="cosine",
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
beta_schedule = 'cosine',
image_sizes = None # for cascading ddpm, image size at each stage
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.net = net
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.image_size = clip.image_size
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
self.unets = nn.ModuleList([])
for ind, one_unet in enumerate(cast_tuple(unet)):
is_first = ind == 0
one_unet = one_unet.force_lowres_cond(not is_first)
self.unets.append(one_unet)
# unet image sizes
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.cond_drop_prob = cond_drop_prob
if beta_schedule == "cosine":
@@ -1048,6 +1109,7 @@ class Decoder(nn.Module):
return text_encodings[:, 1:]
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -1074,8 +1136,9 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale))
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
if clip_denoised:
x_recon.clamp_(-1., 1.)
@@ -1084,33 +1147,25 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
return img
@torch.no_grad()
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
text_encodings = self.get_text_encodings(text) if exists(text) else None
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -1119,16 +1174,17 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None):
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
x_recon = self.net(
x_recon = unet(
x_noisy,
t,
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob
)
@@ -1143,17 +1199,45 @@ class Decoder(nn.Module):
return loss
def forward(self, image, text = None):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
@torch.no_grad()
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
channels = self.channels
text_encodings = self.get_text_encodings(text) if exists(text) else None
img = None
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
shape = (batch_size, channels, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
return img
def forward(self, image, text = None, image_embed = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
index = unet_number - 1
unet = self.unets[index]
target_image_size = self.image_sizes[index]
b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels)
assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
if not exists(image_embed):
image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings)
return loss
lowres_cond_img = image if index > 0 else None
ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
# main class

setup.py

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.18',
version = '0.0.23',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',