diff --git a/README.md b/README.md index 3712257..6a85973 100644 --- a/README.md +++ b/README.md @@ -1298,4 +1298,14 @@ For detailed information on training the diffusion prior, please refer to the [d } ``` +```bibtex +@article{Salimans2022ProgressiveDF, + title = {Progressive Distillation for Fast Sampling of Diffusion Models}, + author = {Tim Salimans and Jonathan Ho}, + journal = {ArXiv}, + year = {2022}, + volume = {abs/2202.00512} +} +``` + *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's paper diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 534b3dc..bc71e91 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -619,7 +619,7 @@ class NoiseScheduler(nn.Module): posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance_clipped - def q_sample(self, x_start, t, noise=None): + def q_sample(self, x_start, t, noise = None): noise = default(noise, lambda: torch.randn_like(x_start)) return ( @@ -627,6 +627,12 @@ class NoiseScheduler(nn.Module): extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise ) + def calculate_v(self, x_start, t, noise = None): + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start + ) + def q_sample_from_to(self, x_from, from_t, to_t, noise = None): shape = x_from.shape noise = default(noise, lambda: torch.randn_like(x_from)) @@ -638,6 +644,12 @@ class NoiseScheduler(nn.Module): return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha + def predict_start_from_v(self, x_t, t, v): + return ( + extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + def predict_start_from_noise(self, x_t, t, noise): return ( extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - @@ -1146,6 +1158,7 @@ class DiffusionPrior(nn.Module): image_cond_drop_prob = None, loss_type = "l2", predict_x_start = True, + predict_v = False, beta_schedule = "cosine", condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs) @@ -1197,6 +1210,7 @@ class DiffusionPrior(nn.Module): # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. self.predict_x_start = predict_x_start + self.predict_v = predict_v # takes precedence over predict_x_start # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 @@ -1226,7 +1240,9 @@ class DiffusionPrior(nn.Module): pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond) - if self.predict_x_start: + if self.predict_v: + x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred) + elif self.predict_x_start: x_start = pred else: x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) @@ -1299,7 +1315,9 @@ class DiffusionPrior(nn.Module): # derive x0 - if self.predict_x_start: + if self.predict_v: + x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred) + elif self.predict_x_start: x_start = pred else: x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise) @@ -1314,7 +1332,7 @@ class DiffusionPrior(nn.Module): # predict noise - if self.predict_x_start: + if self.predict_x_start or self.predict_v: pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start) else: pred_noise = pred @@ -1372,7 +1390,12 @@ class DiffusionPrior(nn.Module): if self.predict_x_start and self.training_clamp_l2norm: pred = self.l2norm_clamp_embed(pred) - target = noise if not self.predict_x_start else image_embed + if self.predict_v: + target = self.noise_scheduler.calculate_v(image_embed, times, noise) + elif self.predict_x_start: + target = image_embed + else: + target = noise loss = self.noise_scheduler.loss_fn(pred, target) return loss @@ -2448,6 +2471,7 @@ class Decoder(nn.Module): loss_type = 'l2', beta_schedule = None, predict_x_start = False, + predict_v = False, predict_x_start_for_latent_diffusion = False, image_sizes = None, # for cascading ddpm, image size at each stage random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) @@ -2620,6 +2644,10 @@ class Decoder(nn.Module): self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) + # predict v + + self.predict_v = cast_tuple(predict_v, len(unets)) + # input image range self.input_image_range = (-1. if not auto_normalize_img else 0., 1.) @@ -2731,14 +2759,16 @@ class Decoder(nn.Module): x = x.clamp(-s, s) / s return x - def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None): + def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None): assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)' model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level)) pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output) - if predict_x_start: + if predict_v: + x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred) + elif predict_x_start: x_start = pred else: x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred) @@ -2765,9 +2795,9 @@ class Decoder(nn.Module): return model_mean, posterior_variance, posterior_log_variance, x_start @torch.no_grad() - def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None): + def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None): b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level) + model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level) noise = torch.randn_like(x) # no noise when t == 0 nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) @@ -2782,6 +2812,7 @@ class Decoder(nn.Module): image_embed, noise_scheduler, predict_x_start = False, + predict_v = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, @@ -2840,6 +2871,7 @@ class Decoder(nn.Module): lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level, predict_x_start = predict_x_start, + predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, clip_denoised = clip_denoised @@ -2865,6 +2897,7 @@ class Decoder(nn.Module): timesteps, eta = 1., predict_x_start = False, + predict_v = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, @@ -2926,7 +2959,9 @@ class Decoder(nn.Module): # predict x0 - if predict_x_start: + if predict_v: + x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred) + elif predict_x_start: x_start = pred else: x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred) @@ -2938,8 +2973,8 @@ class Decoder(nn.Module): # predict noise - if predict_x_start: - pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred) + if predict_x_start or predict_v: + pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start) else: pred_noise = pred @@ -2975,7 +3010,7 @@ class Decoder(nn.Module): return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs) - def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None): + def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None): noise = default(noise, lambda: torch.randn_like(x_start)) # normalize to [-1, 1] @@ -3020,7 +3055,12 @@ class Decoder(nn.Module): pred, _ = self.parse_unet_output(learned_variance, unet_output) - target = noise if not predict_x_start else x_start + if predict_v: + target = noise_scheduler.calculate_v(x_start, times, noise) + elif predict_x_start: + target = x_start + else: + target = noise loss = noise_scheduler.loss_fn(pred, target, reduction = 'none') loss = reduce(loss, 'b ... -> b (...)', 'mean') @@ -3106,7 +3146,7 @@ class Decoder(nn.Module): num_unets = self.num_unets cond_scale = cast_tuple(cond_scale, num_unets) - for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)): + for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)): if unet_number < start_at_unet_number: continue # It's the easiest way to do it @@ -3142,6 +3182,7 @@ class Decoder(nn.Module): text_encodings = text_encodings, cond_scale = unet_cond_scale, predict_x_start = predict_x_start, + predict_v = predict_v, learned_variance = learned_variance, clip_denoised = not is_latent_diffusion, lowres_cond_img = lowres_cond_img, @@ -3181,6 +3222,7 @@ class Decoder(nn.Module): lowres_conditioner = self.lowres_conds[unet_index] target_image_size = self.image_sizes[unet_index] predict_x_start = self.predict_x_start[unet_index] + predict_v = self.predict_v[unet_index] random_crop_size = self.random_crop_sizes[unet_index] learned_variance = self.learned_variance[unet_index] b, c, h, w, device, = *image.shape, image.device @@ -3219,7 +3261,7 @@ class Decoder(nn.Module): image = vae.encode(image) lowres_cond_img = maybe(vae.encode)(lowres_cond_img) - losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level) + losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level) if not return_lowres_cond_image: return losses diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index c283076..522ba08 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.10.9' +__version__ = '1.11.1'