From 0332eaa6ff7b5804a42495cbcedad0b31928e51f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Mon, 18 Apr 2022 11:44:56 -0700 Subject: [PATCH] complete first pass at full cascading DDPM setup in Decoder, flexible enough to support one unet for testing --- README.md | 102 ++++++++++++++++- dalle2_pytorch/dalle2_pytorch.py | 183 ++++++++++++++++++++----------- setup.py | 2 +- 3 files changed, 214 insertions(+), 73 deletions(-) diff --git a/README.md b/README.md index df31c77..1d0231a 100644 --- a/README.md +++ b/README.md @@ -182,6 +182,81 @@ loss.backward() # now the diffusion prior can generate image embeddings from the text embeddings ``` +In the paper, they actually used a recently discovered technique, from Jonathan Ho himself (original author of DDPMs, from which DALL-E2 is based). + +This can easily be used within the framework offered in this repository as so + +```python +import torch +from dalle2_pytorch import Unet, Decoder, CLIP + +# trained clip from step 1 + +clip = CLIP( + dim_text = 512, + dim_image = 512, + dim_latent = 512, + num_text_tokens = 49408, + text_enc_depth = 1, + text_seq_len = 256, + text_heads = 8, + visual_enc_depth = 1, + visual_image_size = 256, + visual_patch_size = 32, + visual_heads = 8 +).cuda() + +# 2 unets for the decoder (a la cascading DDPM) + +unet1 = Unet( + dim = 16, + image_embed_dim = 512, + channels = 3, + dim_mults = (1, 2, 4, 8) +).cuda() + +unet2 = Unet( + dim = 16, + image_embed_dim = 512, + lowres_cond = True, # subsequence unets must have this turned on (and first unet must have this turned off) + cond_dim = 128, + channels = 3, + dim_mults = (1, 2, 4, 8, 16) +).cuda() + +# decoder, which contains the unet and clip + +decoder = Decoder( + clip = clip, + unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here) + image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second + timesteps = 100, + cond_drop_prob = 0.2 +).cuda() + +# mock images (get a lot of this) + +images = torch.randn(4, 3, 512, 512).cuda() + +# feed images into decoder, specifying which unet you want to train +# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme + +loss = decoder(images, unet_number = 1) +loss.backward() + +loss = decoder(images, unet_number = 2) +loss.backward() + +# do the above for many steps for both unets + +# then it will learn to generate images based on the CLIP image embeddings + +# chaining the unets from lowest resolution to highest resolution (thus cascading) + +mock_image_embed = torch.randn(1, 512).cuda() +images = decoder.sample(mock_image_embed) # (1, 3, 512, 512) +``` + Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer) ```python @@ -261,7 +336,7 @@ loss.backward() # decoder (with unet) -unet = Unet( +unet1 = Unet( dim = 128, image_embed_dim = 512, cond_dim = 128, @@ -269,15 +344,26 @@ unet = Unet( dim_mults=(1, 2, 4, 8) ).cuda() +unet2 = Unet( + dim = 16, + image_embed_dim = 512, + cond_dim = 128, + channels = 3, + dim_mults = (1, 2, 4, 8, 16), + lowres_cond = True +).cuda() + decoder = Decoder( - net = unet, + unet = (unet1, unet2), + image_sizes = (128, 256), clip = clip, timesteps = 100, cond_drop_prob = 0.2 ).cuda() -loss = decoder(images) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much -loss.backward() +for unet_number in (1, 2): + loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much + loss.backward() # do above for many steps @@ -291,11 +377,13 @@ images = dalle2( cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition) ) -# save your image +# save your image (in this example, of size 256x256) ``` Everything in this readme should run without error +You can also train the decoder on images of greater than the size (say 512x512) at which CLIP was trained (256x256). The images will be resized to CLIP image resolution for the image embeddings + For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training. ## CLI Usage (work in progress) @@ -321,7 +409,9 @@ Offer training wrappers - [x] make sure it works end to end to produce an output tensor, taking a single gradient step - [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference) - [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper) -- [ ] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions +- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions +- [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest +- [ ] make unet more configurable - [ ] train on a toy task, offer in colab - [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 84a7528..ea69b5b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -29,6 +29,9 @@ def default(val, d): return val return d() if isfunction(d) else d +def cast_tuple(val, length = 1): + return val if isinstance(val, tuple) else ((val,) * length) + def eval_decorator(fn): def inner(model, *args, **kwargs): was_training = model.training @@ -64,6 +67,15 @@ def freeze_model_and_make_eval_(model): def l2norm(t): return F.normalize(t, dim = -1) +def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight + shape = cast_tuple(image_size, 2) + orig_image_size = t.shape[-2:] + + if orig_image_size == shape: + return t + + return F.interpolate(t, size = shape, mode = mode) + # classifier free guidance functions def prob_mask_like(shape, prob, device): @@ -585,31 +597,6 @@ class DiffusionPrior(nn.Module): img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond) return img - @torch.no_grad() - def sample(self, text, num_samples_per_batch = 2): - # in the paper, what they did was - # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP - text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch) - - batch_size = text.shape[0] - image_embed_dim = self.image_embed_dim - - text_cond = self.get_text_cond(text) - - image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) - text_embeds = text_cond['text_embed'] - - text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch) - image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch) - - text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds)) - top_sim_indices = text_image_sims.topk(k = 1).indices - - top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim) - - top_image_embeds = image_embeds.gather(1, top_sim_indices) - return rearrange(top_image_embeds, 'b 1 d -> b d') - def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) @@ -643,6 +630,32 @@ class DiffusionPrior(nn.Module): return loss + @torch.no_grad() + @eval_decorator + def sample(self, text, num_samples_per_batch = 2): + # in the paper, what they did was + # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP + text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch) + + batch_size = text.shape[0] + image_embed_dim = self.image_embed_dim + + text_cond = self.get_text_cond(text) + + image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond) + text_embeds = text_cond['text_embed'] + + text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch) + image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch) + + text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds)) + top_sim_indices = text_image_sims.topk(k = 1).indices + + top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim) + + top_image_embeds = image_embeds.gather(1, top_sim_indices) + return rearrange(top_image_embeds, 'b 1 d -> b d') + def forward(self, text, image, *args, **kwargs): b, device, img_size, = image.shape[0], image.device, self.image_size check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels) @@ -797,7 +810,8 @@ class Unet(nn.Module): channels = 3, lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ lowres_cond_upsample_mode = 'bilinear', - blur_sigma = 0.1 + blur_sigma = 0.1, + attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) ): super().__init__() @@ -847,27 +861,30 @@ class Unet(nn.Module): num_resolutions = len(in_out) for ind, (dim_in, dim_out) in enumerate(in_out): + is_first = ind == 0 is_last = ind >= (num_resolutions - 1) + layer_cond_dim = cond_dim if not is_first else None self.downs.append(nn.ModuleList([ ConvNextBlock(dim_in, dim_out, norm = ind != 0), - ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim), + ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim), Downsample(dim_out) if not is_last else nn.Identity() ])) mid_dim = dims[-1] self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) - self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) + self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): - is_last = ind >= (num_resolutions - 1) + is_last = ind >= (num_resolutions - 2) + layer_cond_dim = cond_dim if not is_last else None self.ups.append(nn.ModuleList([ - ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim), - ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim), - Upsample(dim_in) if not is_last else nn.Identity() + ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim), + ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim), + Upsample(dim_in) ])) out_dim = default(out_dim, channels) @@ -904,14 +921,14 @@ class Unet(nn.Module): # add low resolution conditioning, if present - assert not self.lowres_cond and not exists(lowres_cond_img), 'low resolution conditioning image must be present' + assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' if exists(lowres_cond_img): if self.training: # when training, blur the low resolution conditional image lowres_cond_img = self.lowres_cond_blur(lowres_cond_img) - lowres_cond_img = F.interpolate(lowres_cond_img, size = x.shape[-2:], mode = self.lowres_cond_upsample_mode) + lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode) x = torch.cat((x, lowres_cond_img), dim = 1) # time conditioning @@ -964,7 +981,10 @@ class Unet(nn.Module): x = downsample(x) x = self.mid_block1(x, mid_c) - x = self.mid_attn(x) + + if exists(self.mid_attn): + x = self.mid_attn(x) + x = self.mid_block2(x, mid_c) for convnext, convnext2, upsample in self.ups: @@ -978,22 +998,32 @@ class Unet(nn.Module): class Decoder(nn.Module): def __init__( self, - net, + unet, *, clip, - timesteps=1000, - cond_drop_prob=0.2, - loss_type="l1", - beta_schedule="cosine", + timesteps = 1000, + cond_drop_prob = 0.2, + loss_type = 'l1', + beta_schedule = 'cosine', + image_sizes = None # for cascading ddpm, image size at each stage ): super().__init__() assert isinstance(clip, CLIP) freeze_model_and_make_eval_(clip) self.clip = clip - - self.net = net + self.clip_image_size = clip.image_size self.channels = clip.image_channels - self.image_size = clip.image_size + + self.unets = cast_tuple(unet) + image_sizes = default(image_sizes, (clip.image_size,)) + image_sizes = tuple(sorted(set(image_sizes))) + + assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}' + self.image_sizes = image_sizes + + lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) + assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True' + self.cond_drop_prob = cond_drop_prob if beta_schedule == "cosine": @@ -1048,6 +1078,7 @@ class Decoder(nn.Module): return text_encodings[:, 1:] def get_image_embed(self, image): + image = resize_image_to(image, self.clip_image_size) image_encoding = self.clip.visual_transformer(image) image_cls = image_encoding[:, 0] image_embed = self.clip.to_visual_latent(image_cls) @@ -1074,8 +1105,9 @@ class Decoder(nn.Module): posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, x, t, image_embed, text_encodings = None, clip_denoised = True, cond_scale = 1.): - x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale)) + def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.): + pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) + x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise) if clip_denoised: x_recon.clamp_(-1., 1.) @@ -1084,33 +1116,25 @@ class Decoder(nn.Module): return model_mean, posterior_variance, posterior_log_variance @torch.no_grad() - def p_sample(self, x, t, image_embed, text_encodings = None, cond_scale = 1., clip_denoised = True, repeat_noise = False): + def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False): b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, clip_denoised = clip_denoised) + model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised) noise = noise_like(x.shape, device, repeat_noise) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def p_sample_loop(self, shape, image_embed, text_encodings = None, cond_scale = 1): + def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1): device = self.betas.device b = shape[0] - img = torch.randn(shape, device=device) + img = torch.randn(shape, device = device) - for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): - img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale) + for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): + img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) return img - @torch.no_grad() - def sample(self, image_embed, text = None, cond_scale = 1.): - batch_size = image_embed.shape[0] - image_size = self.image_size - channels = self.channels - text_encodings = self.get_text_encodings(text) if exists(text) else None - return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale) - def q_sample(self, x_start, t, noise=None): noise = default(noise, lambda: torch.randn_like(x_start)) @@ -1119,16 +1143,17 @@ class Decoder(nn.Module): extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) - def p_losses(self, x_start, t, *, image_embed, text_encodings = None, noise = None): + def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None): noise = default(noise, lambda: torch.randn_like(x_start)) x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) - x_recon = self.net( + x_recon = unet( x_noisy, t, image_embed = image_embed, text_encodings = text_encodings, + lowres_cond_img = lowres_cond_img, cond_drop_prob = self.cond_drop_prob ) @@ -1143,17 +1168,43 @@ class Decoder(nn.Module): return loss - def forward(self, image, text = None): - b, device, img_size, = image.shape[0], image.device, self.image_size - check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels) + @torch.no_grad() + @eval_decorator + def sample(self, image_embed, text = None, cond_scale = 1.): + batch_size = image_embed.shape[0] + channels = self.channels + + text_encodings = self.get_text_encodings(text) if exists(text) else None + + img = None + for unet, image_size in tqdm(zip(self.unets, self.image_sizes)): + shape = (batch_size, channels, image_size, image_size) + img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img) + + return img + + def forward(self, image, text = None, unet_number = None): + assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' + unet_number = default(unet_number, 1) + assert 1 <= unet_number <= len(self.unets) + + index = unet_number - 1 + unet = self.unets[index] + target_image_size = self.image_sizes[index] + + b, c, h, w, device, = *image.shape, image.device + + check_shape(image, 'b c h w', c = self.channels) + assert h >= target_image_size and w >= target_image_size times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long) image_embed = self.get_image_embed(image) text_encodings = self.get_text_encodings(text) if exists(text) else None - loss = self.p_losses(image, times, image_embed = image_embed, text_encodings = text_encodings) - return loss + lowres_cond_img = image if index > 0 else None + ddpm_image = resize_image_to(image, target_image_size) + return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img) # main class diff --git a/setup.py b/setup.py index 2442f0c..fcf89e3 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.18', + version = '0.0.20', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',