From 9440411954dcb05ccbc8d09e654f21fb60f1e88e Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 12 Aug 2022 12:20:51 -0700
Subject: [PATCH] make self conditioning technique work with diffusion prior

---
 dalle2_pytorch/dalle2_pytorch.py | 50 ++++++++++++++++++++++++++------
 dalle2_pytorch/version.py        |  2 +-
 2 files changed, 42 insertions(+), 10 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 5375a24..5de4a90 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -937,9 +937,12 @@ class DiffusionPriorNetwork(nn.Module):
         num_image_embeds = 1,
         num_text_embeds = 1,
         max_text_len = 256,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
+
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
         self.num_text_embeds = num_text_embeds
@@ -967,6 +970,10 @@ class DiffusionPriorNetwork(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
 
+        # whether to use self conditioning, Hinton's group's new ddpm technique
+
+        self.self_cond = self_cond
+
     def forward_with_cond_scale(
         self,
         *args,
@@ -988,12 +995,19 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
+        self_cond = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
         num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
 
+        # setup self conditioning
+
+        self_cond = None
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
+
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
         text_embed = self.to_text_embeds(text_embed)
@@ -1043,13 +1057,16 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-        attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
         mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
 
         time_embed = self.to_time_embeds(diffusion_timesteps)
 
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
 
+        if self.self_cond:
+            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
+
         tokens = torch.cat((
             text_encodings,
             text_embed,
@@ -1151,10 +1168,10 @@ class DiffusionPrior(nn.Module):
     def l2norm_clamp_embed(self, image_embed):
         return l2norm(image_embed) * self.image_embed_scale
 
-    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
-        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
         if self.predict_x_start:
             x_start = pred
@@ -1168,28 +1185,33 @@ class DiffusionPrior(nn.Module):
             x_start = l2norm(x_start) * self.image_embed_scale
 
         model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
-        return model_mean, posterior_variance, posterior_log_variance
+        return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start
 
     @torch.no_grad()
     def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
         batch, device = shape[0], self.device
+
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning
 
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
-            image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
+
+            self_cond = x_start if self.net.self_cond else None
+            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
 
         if self.sampling_final_clamp_l2norm and self.predict_x_start:
             image_embed = self.l2norm_clamp_embed(image_embed)
@@ -1207,6 +1229,8 @@ class DiffusionPrior(nn.Module):
 
         image_embed = torch.randn(shape, device = device)
 
+        x_start = None # for self-conditioning
+
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
@@ -1216,7 +1240,9 @@ class DiffusionPrior(nn.Module):
 
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
 
-            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+            self_cond = x_start if self.net.self_cond else None
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
 
             if self.predict_x_start:
                 x_start = pred
@@ -1260,9 +1286,15 @@ class DiffusionPrior(nn.Module):
 
         image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
 
+        self_cond = None
+        if self.net.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
+
         pred = self.net(
             image_embed_noisy,
             times,
+            self_cond = self_cond,
             cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index bcd8d54..bb64aa4 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.6.0'
+__version__ = '1.6.1'
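Note: below is a minimal usage sketch of the new flag, following the README-style setup of this repository. It is illustrative only -- OpenAIClipAdapter, the hyperparameters, and the mock tensor shapes are assumptions, not part of this patch; the only new piece is self_cond = True on DiffusionPriorNetwork. The 50%-probability extra forward pass during training and the reuse of the previous x_start prediction during sampling both happen inside DiffusionPrior, so no other call sites need to change.

import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter

# CLIP adapter as in the README (loads OpenAI CLIP weights)
clip = OpenAIClipAdapter()

# hypothetical hyperparameters, chosen only for illustration
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8,
    self_cond = True    # the new flag introduced by this patch
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2
)

# mock text tokens and images, just to exercise the code path
text = torch.randint(0, 49408, (4, 256))
images = torch.randn(4, 3, 256, 256)

# training: p_losses internally runs an extra no-grad pass ~50% of the time
# and feeds the detached prediction back in as self_cond
loss = diffusion_prior(text, images)
loss.backward()

# sampling: each denoising step passes the previous x_start estimate as self_cond
image_embed = diffusion_prior.sample(text)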