From b7e22f7da0bb4ec2b38eddae7f481dd917718591 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 9 Jul 2022 17:25:34 -0700
Subject: [PATCH] complete ddim integration of diffusion prior as well as
 decoder for each unet, feature complete for
 https://github.com/lucidrains/DALLE2-pytorch/issues/157

---
 README.md                        |   4 +-
 dalle2_pytorch/dalle2_pytorch.py | 200 ++++++++++++++++++++++++++-----
 dalle2_pytorch/train_configs.py  |   2 +
 dalle2_pytorch/trainer.py        |  12 +-
 dalle2_pytorch/version.py        |   2 +-
 5 files changed, 186 insertions(+), 34 deletions(-)

diff --git a/README.md b/README.md
index aabb062..abb1d3a 100644
--- a/README.md
+++ b/README.md
@@ -583,6 +583,7 @@ unet1 = Unet(
     cond_dim = 128,
     channels = 3,
     dim_mults=(1, 2, 4, 8),
+    text_embed_dim = 512,
     cond_on_text_encodings = True # set to True for any unets that need to be conditioned on text encodings (ex. first unet in cascade)
 ).cuda()

@@ -598,7 +599,8 @@ decoder = Decoder(
     unet = (unet1, unet2),
     image_sizes = (128, 256),
     clip = clip,
-    timesteps = 100,
+    timesteps = 1000,
+    sample_timesteps = (250, 27),
     image_cond_drop_prob = 0.1,
     text_cond_drop_prob = 0.5
 ).cuda()
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 48b332c..6e32b0c 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -505,6 +505,12 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )

+    def predict_noise_from_start(self, x_t, t, x0):
+        return (
+            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
+            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+        )
+
     def p2_reweigh_loss(self, loss, times):
         if not self.has_p2_loss_reweighting:
             return loss
@@ -911,6 +917,7 @@ class DiffusionPrior(nn.Module):
         image_size = None,
         image_channels = 3,
         timesteps = 1000,
+        sample_timesteps = None,
         cond_drop_prob = 0.,
         loss_type = "l2",
         predict_x_start = True,
@@ -924,6 +931,8 @@
     ):
         super().__init__()

+        self.sample_timesteps = sample_timesteps
+
         self.noise_scheduler = NoiseScheduler(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
@@ -978,8 +987,6 @@

         if self.predict_x_start:
             x_recon = pred
-            # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
-            # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
             x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

@@ -1002,21 +1009,75 @@
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
-        device = self.device
-
-        b = shape[0]
-        image_embed = torch.randn(shape, device=device)
+    def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
+        batch, device = shape[0], self.device
+        image_embed = torch.randn(shape, device = device)

         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale

         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
-            times = torch.full((b,), i, device = device, dtype = torch.long)
+            times = torch.full((batch,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)

         return image_embed

+    @torch.no_grad()
+    def p_sample_loop_ddim(self, shape, text_cond, *, timesteps, eta = 1., cond_scale = 1.):
+        batch, device, alphas, total_timesteps = shape[0], self.device, self.noise_scheduler.alphas_cumprod_prev, self.noise_scheduler.num_timesteps
+
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        image_embed = torch.randn(shape, device = device)
+
+        if self.init_image_embed_l2norm:
+            image_embed = l2norm(image_embed) * self.image_embed_scale
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = alphas[time]
+            alpha_next = alphas[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+
+            if self.predict_x_start:
+                x_start = pred
+                pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = pred)
+            else:
+                x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
+                pred_noise = pred
+
+            if not self.predict_x_start:
+                x_start.clamp_(-1., 1.)
+
+            if self.predict_x_start and self.sampling_clamp_l2norm:
+                x_start = l2norm(x_start) * self.image_embed_scale
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+            new_noise = torch.randn_like(image_embed)
+
+            image_embed = x_start * alpha_next.sqrt() + \
+                c1 * new_noise + \
+                c2 * pred_noise
+
+        return image_embed
+
+    @torch.no_grad()
+    def p_sample_loop(self, *args, timesteps = None, **kwargs):
+        timesteps = default(timesteps, self.noise_scheduler.num_timesteps)
+        assert timesteps <= self.noise_scheduler.num_timesteps
+        is_ddim = timesteps < self.noise_scheduler.num_timesteps
+
+        if not is_ddim:
+            return self.p_sample_loop_ddpm(*args, **kwargs)
+
+        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))

@@ -1051,7 +1112,15 @@

     @torch.no_grad()
     @eval_decorator
-    def sample(self, text, num_samples_per_batch = 2, cond_scale = 1.):
+    def sample(
+        self,
+        text,
+        num_samples_per_batch = 2,
+        cond_scale = 1.,
+        timesteps = None
+    ):
+        timesteps = default(timesteps, self.sample_timesteps)
+
         # in the paper, what they did was
         # sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
         text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
@@ -1066,7 +1135,7 @@ class DiffusionPrior(nn.Module):
         if self.condition_on_text_encodings:
             text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}

-        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale)
+        image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond, cond_scale = cond_scale, timesteps = timesteps)

         # retrieve original unscaled image embed

@@ -1853,6 +1922,7 @@ class Decoder(nn.Module):
         channels = 3,
         vae = tuple(),
         timesteps = 1000,
+        sample_timesteps = None,
         image_cond_drop_prob = 0.1,
         text_cond_drop_prob = 0.5,
         loss_type = 'l2',
@@ -1876,7 +1946,8 @@
         use_dynamic_thres = False, # from the Imagen paper
         dynamic_thres_percentile = 0.9,
         p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
-        p2_loss_weight_k = 1
+        p2_loss_weight_k = 1,
+        ddim_sampling_eta = 1. # can be set to 0. for deterministic sampling afaict
     ):
         super().__init__()

@@ -1956,6 +2027,11 @@
             self.unets.append(one_unet)
             self.vaes.append(one_vae.copy_for_eval())

+        # sampling timesteps, defaults to non-ddim with full timesteps sampling
+
+        self.sample_timesteps = cast_tuple(sample_timesteps, num_unets)
+        self.ddim_sampling_eta = ddim_sampling_eta
+
         # create noise schedulers per unet

         if not exists(beta_schedule):
@@ -1966,7 +2042,9 @@

         self.noise_schedulers = nn.ModuleList([])

-        for unet_beta_schedule, unet_p2_loss_weight_gamma in zip(beta_schedule, p2_loss_weight_gamma):
+        for ind, (unet_beta_schedule, unet_p2_loss_weight_gamma, sample_timesteps) in enumerate(zip(beta_schedule, p2_loss_weight_gamma, self.sample_timesteps)):
+            assert not exists(sample_timesteps) or sample_timesteps <= timesteps, f'sampling timesteps {sample_timesteps} must be less than or equal to the number of training timesteps {timesteps} for unet {ind + 1}'
+
             noise_scheduler = NoiseScheduler(
                 beta_schedule = unet_beta_schedule,
                 timesteps = timesteps,
@@ -2067,6 +2145,26 @@ class Decoder(nn.Module):
         for unet, device in zip(self.unets, devices):
             unet.to(device)

+    def dynamic_threshold(self, x):
+        """ proposed in https://arxiv.org/abs/2205.11487 as an improved clamping in the setting of classifier free guidance """
+
+        # s is the threshold amount
+        # static thresholding would just be s = 1
+        s = 1.
+        if self.use_dynamic_thres:
+            s = torch.quantile(
+                rearrange(x, 'b ... -> b (...)').abs(),
+                self.dynamic_thres_percentile,
+                dim = -1
+            )
+
+            s.clamp_(min = 1.)
+            s = s.view(-1, *((1,) * (x.ndim - 1)))
+
+        # clip by threshold, depending on whether static or dynamic
+        x = x.clamp(-s, s) / s
+        return x
+
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

@@ -2081,21 +2179,7 @@ class Decoder(nn.Module):
             x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)

         if clip_denoised:
-            # s is the threshold amount
-            # static thresholding would just be s = 1
-            s = 1.
-            if self.use_dynamic_thres:
-                s = torch.quantile(
-                    rearrange(x_recon, 'b ... -> b (...)').abs(),
-                    self.dynamic_thres_percentile,
-                    dim = -1
-                )
-
-                s.clamp_(min = 1.)
-                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
-
-            # clip by threshold, depending on whether static or dynamic
-            x_recon = x_recon.clamp(-s, s) / s
+            x_recon = self.dynamic_threshold(x_recon)

         model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)

@@ -2125,7 +2209,7 @@ class Decoder(nn.Module):
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+    def p_sample_loop_ddpm(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
         device = self.device

         b = shape[0]
@@ -2153,6 +2237,61 @@ class Decoder(nn.Module):
         unnormalize_img = self.unnormalize_img(img)
         return unnormalize_img

+    @torch.no_grad()
+    def p_sample_loop_ddim(self, unet, shape, image_embed, noise_scheduler, timesteps, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+        batch, device, total_timesteps, alphas, eta = shape[0], self.device, noise_scheduler.num_timesteps, noise_scheduler.alphas_cumprod_prev, self.ddim_sampling_eta
+
+        times = torch.linspace(0., total_timesteps, steps = timesteps + 2)[:-1]
+
+        times = list(reversed(times.int().tolist()))
+        time_pairs = list(zip(times[:-1], times[1:]))
+
+        img = torch.randn(shape, device = device)
+
+        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step'):
+            alpha = alphas[time]
+            alpha_next = alphas[time_next]
+
+            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
+
+            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+
+            if learned_variance:
+                pred, _ = pred.chunk(2, dim = 1)
+
+            if predict_x_start:
+                x_start = pred
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+            else:
+                x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
+                pred_noise = pred
+
+            if clip_denoised:
+                x_start = self.dynamic_threshold(x_start)
+
+            c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
+            c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
+
+            img = x_start * alpha_next.sqrt() + \
+                c1 * torch.randn_like(img) + \
+                c2 * pred_noise
+
+        img = self.unnormalize_img(img)
+        return img
+
+    @torch.no_grad()
+    def p_sample_loop(self, *args, noise_scheduler, timesteps = None, **kwargs):
+        num_timesteps = noise_scheduler.num_timesteps
+
+        timesteps = default(timesteps, num_timesteps)
+        assert timesteps <= num_timesteps
+        is_ddim = timesteps < num_timesteps
+
+        if not is_ddim:
+            return self.p_sample_loop_ddpm(*args, noise_scheduler = noise_scheduler, **kwargs)
+
+        return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
+
     def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))

@@ -2253,7 +2392,7 @@ class Decoder(nn.Module):
         img = None
         is_cuda = next(self.parameters()).is_cuda

-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, sample_timesteps in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.sample_timesteps)):

             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()

@@ -2282,7 +2421,8 @@ class Decoder(nn.Module):
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
                     is_latent_diffusion = is_latent_diffusion,
-                    noise_scheduler = noise_scheduler
+                    noise_scheduler = noise_scheduler,
+                    timesteps = sample_timesteps
                 )

                 img = vae.decode(img)
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 1bb7bfa..5f3685e 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -154,6 +154,7 @@ class DiffusionPriorConfig(BaseModel):
     image_size: int
     image_channels: int = 3
     timesteps: int = 1000
+    sample_timesteps: Optional[int] = None
     cond_drop_prob: float = 0.
     loss_type: str = 'l2'
     predict_x_start: bool = True
@@ -233,6 +234,7 @@ class DecoderConfig(BaseModel):
     clip: Optional[AdapterConfig]  # The clip model to use if embeddings are not provided
     channels: int = 3
     timesteps: int = 1000
+    sample_timesteps: Optional[SingularOrIterable(int)] = None
     loss_type: str = 'l2'
     beta_schedule: ListOrTuple(str) = 'cosine'
     learned_variance: bool = True
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 9202a02..a1c5b39 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -536,11 +536,19 @@ class DecoderTrainer(nn.Module):
             assert precision_type == torch.float, "DeepSpeed currently only supports float32 precision when using on the fly embedding generation from clip"
             clip = decoder.clip
             clip.to(precision_type)
-        decoder, train_loader, val_loader, *optimizers = list(self.accelerator.prepare(decoder, dataloaders["train"], dataloaders["val"], *optimizers))
+
+        decoder, *optimizers = list(self.accelerator.prepare(decoder, *optimizers))
+
+        self.decoder = decoder
+
+        # prepare dataloaders
+
+        train_loader = val_loader = None
+        if exists(dataloaders):
+            train_loader, val_loader = self.accelerator.prepare(dataloaders["train"], dataloaders["val"])

         self.train_loader = train_loader
         self.val_loader = val_loader
-        self.decoder = decoder

         # store optimizers

diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index 5ec52a9..db7a416 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.18.0'
+__version__ = '0.19.1'
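
With the patch applied, the sampling path is selected by the new `sample_timesteps` arguments: any value below the training `timesteps` makes `p_sample_loop` dispatch to the DDIM loop, while leaving it at `None` keeps the original full-length DDPM sampling. The following is a minimal usage sketch, assuming `clip`, `prior_network`, `unet1` and `unet2` are constructed as in the README example; the concrete step counts are illustrative, not prescribed by the patch.

    import torch
    from dalle2_pytorch import DiffusionPrior, Decoder

    # prior: trained with 1000 timesteps, sampled with 64 DDIM steps by default
    prior = DiffusionPrior(
        net = prior_network,          # assumed DiffusionPriorNetwork from the README setup
        clip = clip,
        timesteps = 1000,
        sample_timesteps = 64
    ).cuda()

    # decoder: one sampling step count per unet; a None entry falls back to full DDPM sampling
    decoder = Decoder(
        unet = (unet1, unet2),
        image_sizes = (128, 256),
        clip = clip,
        timesteps = 1000,
        sample_timesteps = (250, 27),
        ddim_sampling_eta = 0.        # deterministic DDIM updates
    ).cuda()

    text = torch.randint(0, 49408, (2, 256)).cuda()

    image_embed = prior.sample(text, timesteps = 100)   # per-call override of sample_timesteps
    images = decoder.sample(image_embed = image_embed, text = text)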