From 587c8c9b446634efce561c183d414f9630ada5b6 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 28 Apr 2022 21:59:13 -0700 Subject: [PATCH] optimize for clarity --- dalle2_pytorch/dalle2_pytorch.py | 49 ++++++++++++++------------------ 1 file changed, 22 insertions(+), 27 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 58920fa..af37afa 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -223,7 +223,18 @@ class BaseGaussianDiffusion(nn.Module): timesteps, = betas.shape self.num_timesteps = int(timesteps) + + if loss_type == 'l1': + loss_fn = F.l1_loss + elif loss_type == 'l2': + loss_fn = F.mse_loss + elif loss_type == 'huber': + loss_fn = F.smooth_l1_loss + else: + raise NotImplementedError() + self.loss_type = loss_type + self.loss_fn = loss_fn self.register_buffer('betas', betas) self.register_buffer('alphas_cumprod', alphas_cumprod) @@ -703,29 +714,21 @@ class DiffusionPrior(BaseGaussianDiffusion): img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond) return img - def p_losses(self, image_embed, t, text_cond, noise = None): + def p_losses(self, image_embed, times, text_cond, noise = None): noise = default(noise, lambda: torch.randn_like(image_embed)) - image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise) + image_embed_noisy = self.q_sample(x_start = image_embed, t = times, noise = noise) - x_recon = self.net( + pred = self.net( image_embed_noisy, - t, + times, cond_drop_prob = self.cond_drop_prob, **text_cond ) - to_predict = noise if not self.predict_x_start else image_embed - - if self.loss_type == 'l1': - loss = F.l1_loss(to_predict, x_recon) - elif self.loss_type == 'l2': - loss = F.mse_loss(to_predict, x_recon) - elif self.loss_type == "huber": - loss = F.smooth_l1_loss(to_predict, x_recon) - else: - raise NotImplementedError() + target = noise if not self.predict_x_start else image_embed + loss = self.loss_fn(pred, target) return loss @torch.no_grad() @@ -1388,14 +1391,14 @@ class Decoder(BaseGaussianDiffusion): return img - def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): + def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None): noise = default(noise, lambda: torch.randn_like(x_start)) - x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) + x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise) - x_recon = unet( + pred = unet( x_noisy, - t, + times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, @@ -1404,15 +1407,7 @@ class Decoder(BaseGaussianDiffusion): target = noise if not predict_x_start else x_start - if self.loss_type == 'l1': - loss = F.l1_loss(target, x_recon) - elif self.loss_type == 'l2': - loss = F.mse_loss(target, x_recon) - elif self.loss_type == "huber": - loss = F.smooth_l1_loss(target, x_recon) - else: - raise NotImplementedError() - + loss = self.loss_fn(pred, target) return loss @torch.no_grad()