From bfbcc283a3b06a7069948692b2bb6c8d1403021a Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 26 Apr 2022 11:39:12 -0700 Subject: [PATCH] DRY a tiny bit for gaussian diffusion related logic --- README.md | 3 +- dalle2_pytorch/dalle2_pytorch.py | 257 ++++++++++++------------------- 2 files changed, 101 insertions(+), 159 deletions(-) diff --git a/README.md b/README.md index 806ee40..73fc4c2 100644 --- a/README.md +++ b/README.md @@ -643,7 +643,8 @@ Once built, images will be saved to the same directory the command is invoked - [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms - [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0 - [x] use attention-based upsampling https://arxiv.org/abs/2112.11435 -- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in - use inheritance just this once for sharing logic between decoder and prior network ddpms +- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms +- [ ] abstract interface for CLIP adapter class, so other CLIPs can be brought in - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 8dd338b..612dd0f 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -143,6 +143,92 @@ def sigmoid_beta_schedule(timesteps): return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start +class BaseGaussianDiffusion(nn.Module): + def __init__(self, *, beta_schedule, timesteps, loss_type): + super().__init__() + + if beta_schedule == "cosine": + betas = cosine_beta_schedule(timesteps) + elif beta_schedule == "linear": + betas = linear_beta_schedule(timesteps) + elif beta_schedule == "quadratic": + betas = quadratic_beta_schedule(timesteps) + elif beta_schedule == "jsd": + betas = 1.0 / torch.linspace(timesteps, 1, timesteps) + elif beta_schedule == "sigmoid": + betas = sigmoid_beta_schedule(timesteps) + else: + raise NotImplementedError() + + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, axis = 0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + self.register_buffer('betas', betas) + self.register_buffer('alphas_cumprod', alphas_cumprod) + self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) + self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) + self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) + self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + self.register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) + self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def sample(self, *args, **kwargs): + raise NotImplementedError + + def forward(self, *args, **kwargs): + raise NotImplementedError + # diffusion prior class LayerNorm(nn.Module): @@ -481,7 +567,7 @@ class DiffusionPriorNetwork(nn.Module): return pred_image_embed -class DiffusionPrior(nn.Module): +class DiffusionPrior(BaseGaussianDiffusion): def __init__( self, net, @@ -497,7 +583,11 @@ class DiffusionPrior(nn.Module): beta_schedule = "cosine", condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training ): - super().__init__() + super().__init__( + beta_schedule = beta_schedule, + timesteps = timesteps, + loss_type = loss_type + ) if exists(clip): assert isinstance(clip, CLIP) @@ -517,53 +607,6 @@ class DiffusionPrior(nn.Module): self.predict_x_start = predict_x_start # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both. - if beta_schedule == "cosine": - betas = cosine_beta_schedule(timesteps) - elif beta_schedule == "linear": - betas = linear_beta_schedule(timesteps) - elif beta_schedule == "quadratic": - betas = quadratic_beta_schedule(timesteps) - elif beta_schedule == "jsd": - betas = 1.0 / torch.linspace(timesteps, 1, timesteps) - elif beta_schedule == "sigmoid": - betas = sigmoid_beta_schedule(timesteps) - else: - raise NotImplementedError() - - alphas = 1. - betas - alphas_cumprod = torch.cumprod(alphas, axis = 0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.loss_type = loss_type - - self.register_buffer('betas', betas) - self.register_buffer('alphas_cumprod', alphas_cumprod) - self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) - - # calculations for diffusion q(x_t | x_{t-1}) and others - - self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) - self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) - self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) - self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) - self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - - self.register_buffer('posterior_variance', posterior_variance) - - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - - self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) - self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) - self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) - @torch.no_grad() def get_image_embed(self, image): assert exists(self.clip) @@ -587,27 +630,6 @@ class DiffusionPrior(nn.Module): return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0) - def q_mean_variance(self, x_start, t): - mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract(1. - self.alphas_cumprod, t, x_start.shape) - log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, x, t, text_cond, clip_denoised: bool): pred = self.net(x, t, **text_cond) @@ -644,14 +666,6 @@ class DiffusionPrior(nn.Module): img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond) return img - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - def p_losses(self, image_embed, t, text_cond, noise = None): noise = default(noise, lambda: torch.randn_like(image_embed)) @@ -1164,7 +1178,7 @@ class LowresConditioner(nn.Module): return cond_fmap -class Decoder(nn.Module): +class Decoder(BaseGaussianDiffusion): def __init__( self, unet, @@ -1184,7 +1198,12 @@ class Decoder(nn.Module): blur_kernel_size = 3, # cascading ddpm - blur kernel size condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation ): - super().__init__() + super().__init__( + beta_schedule = beta_schedule, + timesteps = timesteps, + loss_type = loss_type + ) + assert isinstance(clip, CLIP) freeze_model_and_make_eval_(clip) self.clip = clip @@ -1248,55 +1267,6 @@ class Decoder(nn.Module): self.cond_drop_prob = cond_drop_prob - # noise schedule - - if beta_schedule == "cosine": - betas = cosine_beta_schedule(timesteps) - elif beta_schedule == "linear": - betas = linear_beta_schedule(timesteps) - elif beta_schedule == "quadratic": - betas = quadratic_beta_schedule(timesteps) - elif beta_schedule == "jsd": - betas = 1.0 / torch.linspace(timesteps, 1, timesteps) - elif beta_schedule == "sigmoid": - betas = sigmoid_beta_schedule(timesteps) - else: - raise NotImplementedError() - - alphas = 1. - betas - alphas_cumprod = torch.cumprod(alphas, axis = 0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) - - timesteps, = betas.shape - self.num_timesteps = int(timesteps) - self.loss_type = loss_type - - self.register_buffer('betas', betas) - self.register_buffer('alphas_cumprod', alphas_cumprod) - self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) - - # calculations for diffusion q(x_t | x_{t-1}) and others - - self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) - self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) - self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) - self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) - self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - - posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) - - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - - self.register_buffer('posterior_variance', posterior_variance) - - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - - self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) - self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) - self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) - def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 @@ -1329,27 +1299,6 @@ class Decoder(nn.Module): image_embed = self.clip.to_visual_latent(image_cls) return l2norm(image_embed) - def q_mean_variance(self, x_start, t): - mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - variance = extract(1. - self.alphas_cumprod, t, x_start.shape) - log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) - return mean, variance, log_variance - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + - extract(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.): pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) @@ -1394,14 +1343,6 @@ class Decoder(nn.Module): return img - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + - extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): noise = default(noise, lambda: torch.randn_like(x_start))