From ecf9e8027d89ea840f6822182a014224e660c9df Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 15 May 2022 19:09:38 -0700
Subject: [PATCH] make sure classifier free guidance is used only if
 conditional dropout is present on the DiffusionPrior and Decoder classes.
 also make sure prior can have a different conditional scale than decoder

---
 dalle2_pytorch/dalle2_pytorch.py | 29 ++++++++++++++++++-----------
 setup.py                         |  2 +-
 2 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index a8fc6b8..4ec9710 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -901,6 +901,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.channels = default(image_channels, lambda: clip.image_channels)
 
         self.cond_drop_prob = cond_drop_prob
+        self.can_classifier_guidance = cond_drop_prob > 0.
         self.condition_on_text_encodings = condition_on_text_encodings
 
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -914,8 +915,10 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.training_clamp_l2norm = training_clamp_l2norm
         self.init_image_embed_l2norm = init_image_embed_l2norm
 
-    def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
-        pred = self.net(x, t, **text_cond)
+    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
+
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
 
         if self.predict_x_start:
             x_recon = pred
@@ -934,16 +937,16 @@ class DiffusionPrior(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.inference_mode()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
+        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.inference_mode()
-    def p_sample_loop(self, shape, text_cond):
+    def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
         device = self.betas.device
 
         b = shape[0]
@@ -954,7 +957,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
             times = torch.full((b,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond)
+            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
 
         return image_embed
 
@@ -980,19 +983,19 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
     @torch.inference_mode()
     @eval_decorator
-    def sample_batch_size(self, batch_size, text_cond):
+    def sample_batch_size(self, batch_size, text_cond, cond_scale = 1.):
         device = self.betas.device
         shape = (batch_size, self.image_embed_dim)
 
         img = torch.randn(shape, device = device)
 
         for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
-            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
+            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
         return img
 
     @torch.inference_mode()
     @eval_decorator
-    def sample(self, text, num_samples_per_batch = 2):
+    def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
         # in the paper, what they did was
         # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
         text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1007,7 +1010,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         if self.condition_on_text_encodings:
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
 
-        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
 
         # retrieve original unscaled image embed
 
@@ -1793,6 +1796,7 @@ class Decoder(BaseGaussianDiffusion):
 
         self.image_cond_drop_prob = image_cond_drop_prob
         self.text_cond_drop_prob = text_cond_drop_prob
+        self.can_classifier_guidance = image_cond_drop_prob > 0. or text_cond_drop_prob > 0.
 
         # whether to clip when sampling
 
@@ -1819,6 +1823,8 @@ class Decoder(BaseGaussianDiffusion):
             unet.cpu()
 
     def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
+        assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
+
         pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
 
         if learned_variance:
@@ -2094,6 +2100,7 @@ class DALLE2(nn.Module):
         self,
         text,
         cond_scale = 1.,
+        prior_cond_scale = 1.,
         return_pil_images = False
     ):
         device = next(self.parameters()).device
@@ -2103,7 +2110,7 @@ class DALLE2(nn.Module):
         text = [text] if not isinstance(text, (list, tuple)) else text
         text = tokenizer.tokenize(text).to(device)
 
-        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
+        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples, cond_scale = prior_cond_scale)
 
         text_cond = text if self.decoder_need_text_cond else None
         images = self.decoder.sample(image_embed, text = text_cond, cond_scale = cond_scale)
diff --git a/setup.py b/setup.py
index 9650985..41e60e8 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.37',
+  version = '0.2.38',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
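
A minimal usage sketch of the two guidance knobs this patch exposes (the sketch is illustrative and not part of the patch): `diffusion_prior` and `decoder` stand in for instances built and trained as in the repository README, with non-zero conditional dropout (cond_drop_prob / image_cond_drop_prob / text_cond_drop_prob), since the new asserts refuse classifier free guidance otherwise.

    from dalle2_pytorch import DALLE2

    # assumed: diffusion_prior and decoder were constructed and trained per the
    # repo README, with conditional dropout enabled on both

    dalle2 = DALLE2(
        prior = diffusion_prior,   # DiffusionPrior trained with cond_drop_prob > 0.
        decoder = decoder          # Decoder trained with image/text cond_drop_prob > 0.
    )

    images = dalle2(
        ['cute puppy chasing after a squirrel'],
        cond_scale = 2.,           # guidance scale used by the decoder unets
        prior_cond_scale = 1.5     # new: the prior can use its own guidance scale
    )

Keeping the two scales separate follows from the patch's motivation: the prior and decoder are trained independently, so the guidance strength that works best for one need not suit the other.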