From 2277b47ffda6a0952bb40c3b326c0214705b831b Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 12 May 2022 14:18:15 -0700
Subject: [PATCH] make sure learned variance can work for any number of unets
 in the decoder, defaulting to the first unet, as was reportedly used in the
 paper

---
 dalle2_pytorch/dalle2_pytorch.py | 147 +++++++++++++++++++++++++++----
 setup.py                         |   2 +-
 2 files changed, 130 insertions(+), 19 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index a442c73..8c63553 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -33,6 +33,10 @@ from rotary_embedding_torch import RotaryEmbedding
 from x_clip import CLIP
 from coca_pytorch import CoCa
 
+# constants
+
+NAT = 1. / math.log(2.)
+
 # helper functions
 
 def exists(val):
@@ -91,6 +95,9 @@ def freeze_model_and_make_eval_(model):
 
 # tensor helpers
 
+def log(t, eps = 1e-12):
+    return torch.log(t.clamp(min = eps))
+
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
@@ -297,6 +304,36 @@ def noise_like(shape, device, repeat=False):
     noise = lambda: torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()
 
+def meanflat(x):
+    return x.mean(dim = tuple(range(1, len(x.shape))))
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+    return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))
+
+def approx_standard_normal_cdf(x):
+    return 0.5 * (1.0 + torch.tanh(((2.0 / math.pi) ** 0.5) * (x + 0.044715 * (x ** 3))))
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):
+    assert x.shape == means.shape == log_scales.shape
+
+    centered_x = x - means
+    inv_stdv = torch.exp(-log_scales)
+    plus_in = inv_stdv * (centered_x + 1. / 255.)
+    cdf_plus = approx_standard_normal_cdf(plus_in)
+    min_in = inv_stdv * (centered_x - 1. / 255.)
+    cdf_min = approx_standard_normal_cdf(min_in)
+    log_cdf_plus = log(cdf_plus)
+    log_one_minus_cdf_min = log(1. - cdf_min)
+    cdf_delta = cdf_plus - cdf_min
+
+    log_probs = torch.where(x < -thres,
+        log_cdf_plus,
+        torch.where(x > thres,
+            log_one_minus_cdf_min,
+            log(cdf_delta)))
+
+    return log_probs
+
 def cosine_beta_schedule(timesteps, s = 0.008):
     """
     cosine schedule
@@ -1266,6 +1303,7 @@ class Unet(nn.Module):
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
+        channels_out = None,
         attn_dim_head = 32,
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
@@ -1296,6 +1334,7 @@ class Unet(nn.Module):
         # determine dimensions
 
         self.channels = channels
+        self.channels_out = default(channels_out, channels)
 
         init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
         init_dim = default(init_dim, dim // 3 * 2)
@@ -1401,11 +1440,9 @@ class Unet(nn.Module):
                 Upsample(dim_in)
             ]))
 
-        out_dim = default(out_dim, channels)
-
         self.final_conv = nn.Sequential(
             ResnetBlock(dim, dim, groups = resnet_groups[0]),
-            nn.Conv2d(dim, out_dim, 1)
+            nn.Conv2d(dim, self.channels_out, 1)
         )
 
     # if the current settings for the unet are not correct
@@ -1415,13 +1452,25 @@ class Unet(nn.Module):
         *,
         lowres_cond,
         channels,
+        channels_out,
         cond_on_image_embeds,
         cond_on_text_encodings
     ):
-        if lowres_cond == self.lowres_cond and channels == self.channels and cond_on_image_embeds == self.cond_on_image_embeds and cond_on_text_encodings == self.cond_on_text_encodings:
+        if lowres_cond == self.lowres_cond and \
+           channels == self.channels and \
+           cond_on_image_embeds == self.cond_on_image_embeds and \
+           cond_on_text_encodings == self.cond_on_text_encodings and \
+           channels_out == self.channels_out:
             return self
 
-        updated_kwargs = {'lowres_cond': lowres_cond, 'channels': channels, 'cond_on_image_embeds': cond_on_image_embeds, 'cond_on_text_encodings': cond_on_text_encodings}
+        updated_kwargs = dict(
+            lowres_cond = lowres_cond,
+            channels = channels,
+            channels_out = channels_out,
+            cond_on_image_embeds = cond_on_image_embeds,
+            cond_on_text_encodings = cond_on_text_encodings
+        )
+
         return self.__class__(**{**self._locals, **updated_kwargs})
 
     def forward_with_cond_scale(
@@ -1621,6 +1670,8 @@ class Decoder(BaseGaussianDiffusion):
         clip_denoised = True,
         clip_x_start = True,
         clip_adapter_overrides = dict(),
+        learned_variance = True,
+        vb_loss_weight = 0.001,
         unconditional = False
     ):
         super().__init__(
@@ -1659,10 +1710,18 @@ class Decoder(BaseGaussianDiffusion):
         unets = cast_tuple(unet)
         vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
 
+        # whether to use learned variance, defaults to True for the first unet in the cascade, as in the paper
+
+        learned_variance = pad_tuple_to_length(cast_tuple(learned_variance), len(unets), fillvalue = False)
+        self.learned_variance = learned_variance
+        self.vb_loss_weight = vb_loss_weight
+
+        # construct unets and vaes
+
         self.unets = nn.ModuleList([])
         self.vaes = nn.ModuleList([])
 
-        for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
+        for ind, (one_unet, one_vae, one_unet_learned_var) in enumerate(zip(unets, vaes, learned_variance)):
             assert isinstance(one_unet, Unet)
             assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
 
@@ -1670,12 +1729,14 @@ class Decoder(BaseGaussianDiffusion):
             latent_dim = one_vae.encoded_dim if exists(one_vae) else None
 
             unet_channels = default(latent_dim, self.channels)
+            unet_channels_out = unet_channels * (1 if not one_unet_learned_var else 2)
 
             one_unet = one_unet.cast_model_parameters(
                lowres_cond = not is_first,
                cond_on_image_embeds = is_first and not unconditional,
                cond_on_text_encodings = one_unet.cond_on_text_encodings and not unconditional,
-                channels = unet_channels
+                channels = unet_channels,
+                channels_out = unet_channels_out
             )
 
             self.unets.append(one_unet)
@@ -1738,8 +1799,11 @@ class Decoder(BaseGaussianDiffusion):
         yield
         unet.cpu()
 
-    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
-        pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None):
+        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img))
+
+        if learned_variance:
+            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
 
         if predict_x_start:
             x_recon = pred
@@ -1750,19 +1814,31 @@ class Decoder(BaseGaussianDiffusion):
             x_recon.clamp_(-1., 1.)
 
         model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+
+        if learned_variance:
+            # if learned variance, the posterior variance and posterior log variance are predicted by the network
+            # by an interpolation of the max and min log beta values
+            # eq 15 - https://arxiv.org/abs/2102.09672
+            min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
+            max_log = extract(torch.log(self.betas), t, x.shape)
+            var_interp_frac = unnormalize_img(var_interp_frac_unnormalized)
+
+            posterior_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
+            posterior_variance = posterior_log_variance.exp()
+
         return model_mean, posterior_variance, posterior_log_variance
 
     @torch.inference_mode()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, learned_variance = False, clip_denoised = True, repeat_noise = False):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start, learned_variance = learned_variance)
         noise = noise_like(x.shape, device, repeat_noise)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
         return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
 
     @torch.inference_mode()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
         device = self.betas.device
 
         b = shape[0]
@@ -1779,17 +1855,18 @@ class Decoder(BaseGaussianDiffusion):
                 cond_scale = cond_scale,
                 lowres_cond_img = lowres_cond_img,
                 predict_x_start = predict_x_start,
+                learned_variance = learned_variance,
                 clip_denoised = clip_denoised
             )
 
         return img
 
-    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
+    def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
 
-        pred = unet(
+        model_output = unet(
            x_noisy,
            times,
            image_embed = image_embed,
@@ -1800,10 +1877,43 @@ class Decoder(BaseGaussianDiffusion):
             text_cond_drop_prob = self.text_cond_drop_prob,
         )
 
+        if learned_variance:
+            pred, _ = model_output.chunk(2, dim = 1)
+        else:
+            pred = model_output
+
         target = noise if not predict_x_start else x_start
 
         loss = self.loss_fn(pred, target)
-        return loss
+
+        if not learned_variance:
+            # return simple loss if not using learned variance
+            return loss
+
+        # if learning the variance, also include the extra weighted kl loss
+
+        true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+
+        # kl loss with detached model predicted mean, for stability reasons as in the paper
+
+        detached_model_mean = model_mean.detach()
+
+        kl = normal_kl(true_mean, true_log_variance_clipped, detached_model_mean, model_log_variance)
+        kl = meanflat(kl) * NAT
+
+        decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = detached_model_mean, log_scales = 0.5 * model_log_variance)
+        decoder_nll = meanflat(decoder_nll) * NAT
+
+        # at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+
+        vb_losses = torch.where(times == 0, decoder_nll, kl)
+
+        # weight the vb loss smaller, for stability, as in the paper (recommended 0.001)
+
+        vb_loss = vb_losses.mean() * self.vb_loss_weight
+
+        return loss + vb_loss
 
     @torch.inference_mode()
     @eval_decorator
@@ -1830,7 +1940,7 @@ class Decoder(BaseGaussianDiffusion):
         img = None
         is_cuda = next(self.parameters()).is_cuda
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance)):
 
             context = self.one_unet_in_gpu(unet = unet) if is_cuda else null_context()
 
@@ -1856,6 +1966,7 @@ class Decoder(BaseGaussianDiffusion):
                     text_mask = text_mask,
                     cond_scale = cond_scale,
                     predict_x_start = predict_x_start,
+                    learned_variance = learned_variance,
                    clip_denoised = not is_latent_diffusion,
                    lowres_cond_img = lowres_cond_img
                )
 
@@ -1885,6 +1996,7 @@ class Decoder(BaseGaussianDiffusion):
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
+        learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device
 
         check_shape(image, 'b c h w', c = self.channels)
@@ -1922,7 +2034,7 @@ class Decoder(BaseGaussianDiffusion):
 
         if exists(lowres_cond_img):
             lowres_cond_img = vae.encode(lowres_cond_img)
 
-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance)
 
 # main class
 
@@ -1972,4 +2084,3 @@ class DALLE2(nn.Module):
             return images[0]
 
         return images
-
diff --git a/setup.py b/setup.py
index 51d842e..49f7a19 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.2.12',
+  version = '0.2.14',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
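
Usage note (not part of the diff): with this change, learned_variance can be passed to Decoder as a single bool or a per-unet tuple; cast_tuple plus pad_tuple_to_length(..., fillvalue = False) means a bare True enables it for the first unet only, and cast_model_parameters doubles that unet's output channels (channels_out) automatically. Below is a minimal training sketch under those assumptions; the CLIP/Unet hyperparameters are illustrative placeholders in the style of the README, not values taken from this patch.

import torch
from dalle2_pytorch import Unet, Decoder, CLIP

# toy CLIP, only so the decoder can derive image embeddings itself
clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

# two-unet cascade; the decoder recasts lowres_cond / channels_out per unet
unet1 = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
unet2 = Unet(dim = 16, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 100,
    learned_variance = True,  # padded to (True, False) - only the first unet learns its variance
    vb_loss_weight = 0.001    # the default recommended in the patch
)

images = torch.randn(4, 3, 256, 256)

# for the first unet, the simple loss now also carries the weighted
# vb term (KL, or decoder NLL at t == 0); the second unet trains as before
for unet_number in (1, 2):
    loss = decoder(images, unet_number = unet_number)
    loss.backward()

At sampling time nothing extra is needed: Decoder.sample zips over self.learned_variance and threads the flag through p_sample_loop per unet.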