From 723bf0abbad7b30f5fb0cdee368e8c3513377a8a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 19 Jul 2022 09:26:55 -0700 Subject: [PATCH] complete inpainting ability using inpaint_image and inpaint_mask passed into sample function for decoder --- README.md | 4 +- dalle2_pytorch/dalle2_pytorch.py | 88 ++++++++++++++++++++++++++++++-- dalle2_pytorch/version.py | 2 +- 3 files changed, 87 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 8ca1821..cad0dbb 100644 --- a/README.md +++ b/README.md @@ -1049,8 +1049,8 @@ Once built, images will be saved to the same directory the command is invoked - [x] test out grid attention in cascading ddpm locally, decide whether to keep or remove https://arxiv.org/abs/2204.01697 (keeping, seems to be fine) - [x] allow for unet to be able to condition non-cross attention style as well - [x] speed up inference, read up on papers (ddim) -- [ ] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865 -- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - consider https://github.com/lucidrains/uformer-pytorch attention-based unet +- [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865 +- [ ] try out the nested unet from https://arxiv.org/abs/2005.09007 after hearing several positive testimonies from researchers, for segmentation anyhow - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 ## Citations diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c1bf9f2..d6cd618 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -2415,20 +2415,51 @@ class Decoder(nn.Module): return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise @torch.no_grad() - def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None): + def p_sample_loop_ddpm( + self, + unet, + shape, + image_embed, + noise_scheduler, + predict_x_start = False, + learned_variance = False, + clip_denoised = True, + lowres_cond_img = None, + text_encodings = None, + cond_scale = 1, + is_latent_diffusion = False, + lowres_noise_level = None, + inpaint_image = None, + inpaint_mask = None + ): device = self.device b = shape[0] img = torch.randn(shape, device = device) + if exists(inpaint_image): + inpaint_image = self.normalize_img(inpaint_image) + inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True) + inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float() + inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True) + inpaint_mask = inpaint_mask.bool() + if not is_latent_diffusion: lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps): + times = torch.full((b,), i, device = device, dtype = torch.long) + + if exists(inpaint_image): + # following the repaint paper + # https://arxiv.org/abs/2201.09865 + noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = times) + img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) + img = self.p_sample( unet, img, - torch.full((b,), i, device = device, dtype = torch.long), + times, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, @@ -2440,11 +2471,32 @@ class Decoder(nn.Module): clip_denoised = clip_denoised ) + if exists(inpaint_image): + img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask) + unnormalize_img = self.unnormalize_img(img) return unnormalize_img @torch.no_grad() - def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, eta = 1., predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, cond_scale = 1, is_latent_diffusion = False, lowres_noise_level = None): + def p_sample_loop_ddim( + self, + unet, + shape, + image_embed, + noise_scheduler, + timesteps, + eta = 1., + predict_x_start = False, + learned_variance = False, + clip_denoised = True, + lowres_cond_img = None, + text_encodings = None, + cond_scale = 1, + is_latent_diffusion = False, + lowres_noise_level = None, + inpaint_image = None, + inpaint_mask = None + ): batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1] @@ -2452,6 +2504,13 @@ class Decoder(nn.Module): times = list(reversed(times.int().tolist())) time_pairs = list(zip(times[:-1], times[1:])) + if exists(inpaint_image): + inpaint_image = self.normalize_img(inpaint_image) + inpaint_image = resize_image_to(inpaint_image, shape[-1], nearest = True) + inpaint_mask = rearrange(inpaint_mask, 'b h w -> b 1 h w').float() + inpaint_mask = resize_image_to(inpaint_mask, shape[-1], nearest = True) + inpaint_mask = inpaint_mask.bool() + img = torch.randn(shape, device = device) if not is_latent_diffusion: @@ -2463,6 +2522,12 @@ class Decoder(nn.Module): time_cond = torch.full((batch,), time, device = device, dtype = torch.long) + if exists(inpaint_image): + # following the repaint paper + # https://arxiv.org/abs/2201.09865 + noised_inpaint_image = noise_scheduler.q_sample(inpaint_image, t = time_cond) + img = (img * ~inpaint_mask) + (noised_inpaint_image * inpaint_mask) + pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level) if learned_variance: @@ -2486,6 +2551,9 @@ class Decoder(nn.Module): c1 * noise + \ c2 * pred_noise + if exists(inpaint_image): + img = (img * ~inpaint_mask) + (inpaint_image * inpaint_mask) + img = self.unnormalize_img(img) return img @@ -2585,6 +2653,8 @@ class Decoder(nn.Module): cond_scale = 1., stop_at_unet_number = None, distributed = False, + inpaint_image = None, + inpaint_mask = None ): assert self.unconditional or exists(image_embed), 'image embed must be present on sampling from decoder unless if trained unconditionally' @@ -2598,6 +2668,8 @@ class Decoder(nn.Module): assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified' assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' + assert not (exists(inpaint_image) ^ exists(inpaint_mask)), 'inpaint_image and inpaint_mask (boolean mask of [batch, height, width]) must be both given for inpainting' + img = None is_cuda = next(self.parameters()).is_cuda @@ -2609,6 +2681,8 @@ class Decoder(nn.Module): context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context() with context: + # prepare low resolution conditioning for upsamplers + lowres_cond_img = lowres_noise_level = None shape = (batch_size, channel, image_size, image_size) @@ -2619,12 +2693,16 @@ class Decoder(nn.Module): lowres_noise_level = torch.full((batch_size,), int(self.lowres_noise_sample_level * 1000), dtype = torch.long, device = self.device) lowres_cond_img, _ = lowres_cond.noise_image(lowres_cond_img, lowres_noise_level) + # latent diffusion + is_latent_diffusion = isinstance(vae, VQGanVAE) image_size = vae.get_encoded_fmap_size(image_size) shape = (batch_size, vae.encoded_dim, image_size, image_size) lowres_cond_img = maybe(vae.encode)(lowres_cond_img) + # denoising loop for image + img = self.p_sample_loop( unet, shape, @@ -2638,7 +2716,9 @@ class Decoder(nn.Module): lowres_noise_level = lowres_noise_level, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, - timesteps = sample_timesteps + timesteps = sample_timesteps, + inpaint_image = inpaint_image, + inpaint_mask = inpaint_mask ) img = vae.decode(img) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index a636f70..826d20e 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.25.2' +__version__ = '0.26.0'