From 138079ca83b35b245609623d5ba893d8e7981c43 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 20 Jun 2022 08:56:32 -0700
Subject: [PATCH] allow for setting the beta schedule of each unet in the decoder separately, as the paper used cosine, cosine, linear

---
 dalle2_pytorch/dalle2_pytorch.py | 130 +++++++++++++++++++------------
 dalle2_pytorch/train_configs.py  |   2 +-
 dalle2_pytorch/trainer.py        |   8 +-
 dalle2_pytorch/version.py        |   2 +-
 4 files changed, 87 insertions(+), 55 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 26ba70a..7467e44 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -378,7 +378,7 @@ def sigmoid_beta_schedule(timesteps):
 
     return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
 
-class BaseGaussianDiffusion(nn.Module):
+class NoiseScheduler(nn.Module):
     def __init__(self, *, beta_schedule, timesteps, loss_type, p2_loss_weight_gamma = 0., p2_loss_weight_k = 1):
         super().__init__()
 
@@ -472,11 +472,10 @@ class BaseGaussianDiffusion(nn.Module):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )
 
-    def sample(self, *args, **kwargs):
-        raise NotImplementedError
-
-    def forward(self, *args, **kwargs):
-        raise NotImplementedError
+    def p2_reweigh_loss(self, loss, times):
+        if not self.has_p2_loss_reweighting:
+            return loss
+        return loss * extract(self.p2_loss_weight, times, loss.shape)
 
 # diffusion prior
 
@@ -687,8 +686,7 @@ class Attention(nn.Module):
         # attention
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
 
         attn = self.dropout(attn)
 
         # aggregate values
@@ -862,7 +860,7 @@ class DiffusionPriorNetwork(nn.Module):
 
         return pred_image_embed
 
-class DiffusionPrior(BaseGaussianDiffusion):
+class DiffusionPrior(nn.Module):
     def __init__(
         self,
         net,
@@ -883,7 +881,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
         image_embed_scale = None,           # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
         clip_adapter_overrides = dict()
     ):
-        super().__init__(
+        super().__init__()
+
+        self.noise_scheduler = NoiseScheduler(
             beta_schedule = beta_schedule,
             timesteps = timesteps,
             loss_type = loss_type
@@ -923,6 +923,13 @@ class DiffusionPrior(BaseGaussianDiffusion):
         self.training_clamp_l2norm = training_clamp_l2norm
         self.init_image_embed_l2norm = init_image_embed_l2norm
 
+        # device tracker
+        self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
+
+    @property
+    def device(self):
+        return self._dummy.device
+
     def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
@@ -933,7 +940,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
             # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
             # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
         else:
-            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
+            x_recon = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
 
         if clip_denoised and not self.predict_x_start:
             x_recon.clamp_(-1., 1.)
@@ -941,7 +948,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         if self.predict_x_start and self.sampling_clamp_l2norm:
             x_recon = l2norm(x_recon) * self.image_embed_scale
 
-        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
@@ -955,7 +962,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
     @torch.no_grad()
     def p_sample_loop(self, shape, text_cond, cond_scale = 1.):
-        device = self.betas.device
+        device = self.device
         b = shape[0]
         image_embed = torch.randn(shape, device=device)
 
@@ -963,7 +970,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
 
-        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
+        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((b,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
 
@@ -972,7 +979,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
 
-        image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise)
+        image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
 
         pred = self.net(
             image_embed_noisy,
@@ -986,7 +993,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
         target = noise if not self.predict_x_start else image_embed
 
-        loss = self.loss_fn(pred, target)
+        loss = self.noise_scheduler.loss_fn(pred, target)
         return loss
 
     @torch.no_grad()
@@ -997,7 +1004,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
 
         img = torch.randn(shape, device = device)
 
-        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+        for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = self.noise_scheduler.num_timesteps):
             img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond, cond_scale = cond_scale)
         return img
 
@@ -1069,7 +1076,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         # timestep conditioning from ddpm
 
         batch, device = image_embed.shape[0], image_embed.device
-        times = torch.randint(0, self.num_timesteps, (batch,), device = device, dtype = torch.long)
+        times = torch.randint(0, self.noise_scheduler.num_timesteps, (batch,), device = device, dtype = torch.long)
 
         # scale image embed (Katherine)
 
@@ -1234,8 +1241,7 @@ class CrossAttention(nn.Module):
             mask = rearrange(mask, 'b j -> b 1 1 j')
             sim = sim.masked_fill(~mask, max_neg_value)
 
-        sim = sim - sim.amax(dim = -1, keepdim = True).detach()
-        attn = sim.softmax(dim = -1)
+        attn = sim.softmax(dim = -1, dtype = torch.float32)
 
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
@@ -1739,7 +1745,7 @@ class LowresConditioner(nn.Module):
 
         return cond_fmap
 
-class Decoder(BaseGaussianDiffusion):
+class Decoder(nn.Module):
     def __init__(
         self,
         unet,
@@ -1752,7 +1758,7 @@ class Decoder(BaseGaussianDiffusion):
         image_cond_drop_prob = 0.1,
         text_cond_drop_prob = 0.5,
         loss_type = 'l2',
-        beta_schedule = 'cosine',
+        beta_schedule = None,
         predict_x_start = False,
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None,                         # for cascading ddpm, image size at each stage
@@ -1774,13 +1780,7 @@ class Decoder(BaseGaussianDiffusion):
         p2_loss_weight_gamma = 0.,                  # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended
         p2_loss_weight_k = 1
     ):
-        super().__init__(
-            beta_schedule = beta_schedule,
-            timesteps = timesteps,
-            loss_type = loss_type,
-            p2_loss_weight_gamma = p2_loss_weight_gamma,
-            p2_loss_weight_k = p2_loss_weight_k
-        )
+        super().__init__()
 
         self.unconditional = unconditional
 
@@ -1824,6 +1824,8 @@ class Decoder(BaseGaussianDiffusion):
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
 
         unets = cast_tuple(unet)
+        num_unets = len(unets)
+
         vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
 
         # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1859,6 +1861,24 @@ class Decoder(BaseGaussianDiffusion):
             self.unets.append(one_unet)
             self.vaes.append(one_vae.copy_for_eval())
 
+        # create noise schedulers per unet (a lone schedule name is broadcast across all unets)
+
+        beta_schedule = default(beta_schedule, ('cosine', *(('cosine',) * max(num_unets - 2, 0)), *(('linear',) * int(num_unets > 1))))
+        beta_schedule = (beta_schedule,) * num_unets if isinstance(beta_schedule, str) else tuple(beta_schedule)
+
+        self.noise_schedulers = nn.ModuleList([])
+
+        for unet_beta_schedule in beta_schedule:
+            noise_scheduler = NoiseScheduler(
+                beta_schedule = unet_beta_schedule,
+                timesteps = timesteps,
+                loss_type = loss_type,
+                p2_loss_weight_gamma = p2_loss_weight_gamma,
+                p2_loss_weight_k = p2_loss_weight_k
+            )
+
+            self.noise_schedulers.append(noise_scheduler)
+
         # unet image sizes
 
         image_sizes = default(image_sizes, (image_size,))
@@ -1908,6 +1928,14 @@ class Decoder(BaseGaussianDiffusion):
         self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
         self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
 
+        # device tracker
+
+        self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
+
+    @property
+    def device(self):
+        return self._dummy.device
+
     def get_unet(self, unet_number):
         assert 0 < unet_number <= len(self.unets)
         index = unet_number - 1
@@ -1931,7 +1959,7 @@ class Decoder(BaseGaussianDiffusion):
         for unet, device in zip(self.unets, devices):
             unet.to(device)
 
-    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
         pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
 
@@ -1942,7 +1970,7 @@ class Decoder(BaseGaussianDiffusion):
         if predict_x_start:
             x_recon = pred
         else:
-            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
+            x_recon = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
 
         if clip_denoised:
             # s is the threshold amount
@@ -1961,14 +1989,14 @@ class Decoder(BaseGaussianDiffusion):
             # clip by threshold, depending on whether static or dynamic
             x_recon = x_recon.clamp(-s, s) / s
 
-        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        model_mean, posterior_variance, posterior_log_variance = noise_scheduler.q_posterior(x_start=x_recon, x_t=x, t=t)
 
         if learned_variance:
             # if learned variance, posterior variance and posterior log variance are predicted by the network
             # by an interpolation of the max and min log beta values
             # eq 15 - https://arxiv.org/abs/2102.09672
-            min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
-            max_log = extract(torch.log(self.betas), t, x.shape)
+            min_log = extract(noise_scheduler.posterior_log_variance_clipped, t, x.shape)
+            max_log = extract(torch.log(noise_scheduler.betas), t, x.shape)
             var_interp_frac = unnormalize_zero_to_one(var_interp_frac_unnormalized)
 
             if self.learned_variance_constrain_frac:
@@ -1980,17 +2008,17 @@ class Decoder(BaseGaussianDiffusion):
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
-        device = self.betas.device
+    def p_sample_loop(self, unet, shape, image_embed, noise_scheduler, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1, is_latent_diffusion = False):
+        device = self.device
         b = shape[0]
         img = torch.randn(shape, device = device)
 
@@ -1998,7 +2026,7 @@ class Decoder(BaseGaussianDiffusion):
         if not is_latent_diffusion:
             lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
 
-        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+        for i in tqdm(reversed(range(0, noise_scheduler.num_timesteps)), desc = 'sampling loop time step', total = noise_scheduler.num_timesteps):
             img = self.p_sample(
                 unet,
                 img,
@@ -2009,6 +2037,7 @@ class Decoder(BaseGaussianDiffusion):
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
                 predict_x_start = predict_x_start,
+                noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
                 clip_denoised = clip_denoised
             )
 
@@ -2016,7 +2045,7 @@ class Decoder(BaseGaussianDiffusion):
         unnormalize_img = self.unnormalize_img(img)
         return unnormalize_img
 
-    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
 
@@ -2027,7 +2056,7 @@ class Decoder(BaseGaussianDiffusion):
 
         # get x_t
 
-        x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
+        x_noisy = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
 
         model_output = unet(
             x_noisy,
@@ -2047,11 +2076,10 @@ class Decoder(BaseGaussianDiffusion):
 
         target = noise if not predict_x_start else x_start
 
-        loss = self.loss_fn(pred, target, reduction = 'none')
+        loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
         loss = reduce(loss, 'b ... -> b (...)', 'mean')
 
-        if self.has_p2_loss_reweighting:
-            loss = loss * extract(self.p2_loss_weight, times, loss.shape)
+        loss = noise_scheduler.p2_reweigh_loss(loss, times)
 
         loss = loss.mean()
 
@@ -2066,8 +2094,8 @@ class Decoder(BaseGaussianDiffusion):
 
         # if learning the variance, also include the extra weight kl loss
 
-        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+        true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
 
         # kl loss with detached model predicted mean, for stability reasons as in paper
 
@@ -2117,7 +2145,7 @@ class Decoder(BaseGaussianDiffusion):
         img = None
         is_cuda = next(self.parameters()).is_cuda
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers)):
 
             context = self.one_unet_in_gpu(unet = unet) if is_cuda and not distributed else null_context()
 
@@ -2145,7 +2173,8 @@ class Decoder(BaseGaussianDiffusion):
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
-                    is_latent_diffusion = is_latent_diffusion
+                    is_latent_diffusion = is_latent_diffusion,
+                    noise_scheduler = noise_scheduler
                 )
 
                 img = vae.decode(img)
 
@@ -2171,6 +2200,7 @@ class Decoder(BaseGaussianDiffusion):
 
         unet = self.get_unet(unet_number)
         vae = self.vaes[unet_index]
+        noise_scheduler = self.noise_schedulers[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
@@ -2180,7 +2210,7 @@ class Decoder(BaseGaussianDiffusion):
         check_shape(image, 'b c h w', c = self.channels)
         assert h >= target_image_size and w >= target_image_size
 
-        times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
+        times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
 
         if not exists(image_embed) and not self.unconditional:
             assert exists(self.clip), 'if you want to derive CLIP image embeddings automatically, you must supply `clip` to the decoder on init'
@@ -2211,7 +2241,7 @@ class Decoder(BaseGaussianDiffusion):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
 
 # main class
 
diff --git a/dalle2_pytorch/train_configs.py b/dalle2_pytorch/train_configs.py
index 56713d6..80bc0cd 100644
--- a/dalle2_pytorch/train_configs.py
+++ b/dalle2_pytorch/train_configs.py
@@ -173,7 +173,7 @@ class DecoderConfig(BaseModel):
     channels: int = 3
     timesteps: int = 1000
     loss_type: str = 'l2'
-    beta_schedule: str = 'cosine'
+    beta_schedule: ListOrTuple(str) = 'cosine'
     learned_variance: bool = True
     image_cond_drop_prob: float = 0.1
     text_cond_drop_prob: float = 0.5
diff --git a/dalle2_pytorch/trainer.py b/dalle2_pytorch/trainer.py
index 7008bf7..bc68caa 100644
--- a/dalle2_pytorch/trainer.py
+++ b/dalle2_pytorch/trainer.py
@@ -24,7 +24,9 @@ def exists(val):
     return val is not None
 
 def default(val, d):
-    return val if exists(val) else d
+    if exists(val):
+        return val
+    return d() if callable(d) else d
 
 def cast_tuple(val, length = 1):
     return val if isinstance(val, tuple) else ((val,) * length)
@@ -574,8 +576,8 @@ def decoder_sample_in_chunks(fn):
 class DecoderTrainer(nn.Module):
     def __init__(
         self,
-        accelerator,
         decoder,
+        accelerator = None,
         use_ema = True,
         lr = 1e-4,
         wd = 1e-2,
@@ -589,7 +591,7 @@
         assert isinstance(decoder, Decoder)
         ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
 
-        self.accelerator = accelerator
+        self.accelerator = default(accelerator, Accelerator)
 
         self.num_unets = len(decoder.unets)
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index e754a83..f323a57 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.10.1'
+__version__ = '0.11.0'
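
Usage sketch (illustrative only; `unet1`, `unet2`, and `clip` are hypothetical and assumed to be built as in the repository README). After this patch, `Decoder` accepts one beta schedule per unet, defaulting to cosine for every unet except the last, which uses linear, matching the schedules used in the paper. `DecoderTrainer` also no longer requires an accelerator to be passed in; it constructs a default `Accelerator` when none is given.

    from dalle2_pytorch import Decoder
    from dalle2_pytorch.trainer import DecoderTrainer

    # unet1, unet2 and clip are assumed to be constructed as in the README (not defined here)
    decoder = Decoder(
        unet = (unet1, unet2),                # cascading ddpm with two unets
        clip = clip,
        image_sizes = (64, 256),              # image size at each stage of the cascade
        timesteps = 1000,
        beta_schedule = ('cosine', 'linear')  # one beta schedule per unet; omitting this yields the same default for two unets
    )

    trainer = DecoderTrainer(decoder)         # accelerator is now optional; a default Accelerator is created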