From 1bf071af7814142259c495fb98abb4f5a10b29ba Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 13 Apr 2022 10:29:23 -0700
Subject: [PATCH] allow for predicting image embedding directly during
 diffusion training. need to fix sampling still

---
 dalle2_pytorch/dalle2_pytorch.py | 12 +++++++++---
 1 file changed, 9 insertions(+), 3 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 45622a0..42be949 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -277,7 +277,8 @@ class DiffusionPrior(nn.Module):
         clip,
         timesteps = 1000,
         cond_prob_drop = 0.2,
-        loss_type = 'l1'
+        loss_type = 'l1',
+        predict_x0 = True
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -289,6 +290,9 @@ class DiffusionPrior(nn.Module):
         self.image_size = clip.image_size
         self.cond_prob_drop = cond_prob_drop
 
+        self.predict_x0 = predict_x0
+        # in the paper, they do not predict the noise, but predict x0 directly for the image embedding, claiming empirically better results. I'll just offer both.
+
         betas = cosine_beta_schedule(timesteps)
         alphas = 1. - betas
 
@@ -415,10 +419,12 @@ class DiffusionPrior(nn.Module):
             **text_cond
         )
 
+        to_predict = noise if not self.predict_x0 else image_embed
+
         if self.loss_type == 'l1':
-            loss = F.l1_loss(noise, x_recon)
+            loss = F.l1_loss(to_predict, x_recon)
         elif self.loss_type == 'l2':
-            loss = F.mse_loss(noise, x_recon)
+            loss = F.mse_loss(to_predict, x_recon)
         else:
             raise NotImplementedError()
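---

For reference, a minimal, self-contained sketch of what the predict_x0 switch
changes on both the training side (the to_predict target above) and the
sampling side that the commit message flags as still unfixed. Everything
outside the to_predict line is assumed scaffolding: ToyPrior, q_sample and
model_output_to_x_start are hypothetical stand-ins, not the repository's
actual classes or helpers.

# hypothetical sketch, not the repository's code
import math
import torch
import torch.nn.functional as F

TIMESTEPS = 1000

def cosine_beta_schedule(timesteps, s = 0.008):
    # cosine noise schedule (Nichol & Dhariwal, 2021), as used by the class above
    steps = torch.linspace(0, timesteps, timesteps + 1)
    alphas_cumprod = torch.cos(((steps / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1. - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas.clamp(0., 0.999)

alphas_cumprod = torch.cumprod(1. - cosine_beta_schedule(TIMESTEPS), dim = 0)

def q_sample(x0, t, noise):
    # forward diffusion: x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise
    a_bar = alphas_cumprod[t].view(-1, 1)
    return a_bar.sqrt() * x0 + (1. - a_bar).sqrt() * noise

def p_losses(model, image_embed, t, predict_x0 = True, loss_type = 'l1'):
    # training objective: regress either x0 (the image embedding itself) or the noise
    noise = torch.randn_like(image_embed)
    x_noisy = q_sample(image_embed, t, noise)
    x_recon = model(x_noisy, t)

    to_predict = noise if not predict_x0 else image_embed  # the switch from the patch

    if loss_type == 'l1':
        return F.l1_loss(to_predict, x_recon)
    elif loss_type == 'l2':
        return F.mse_loss(to_predict, x_recon)
    raise NotImplementedError()

def model_output_to_x_start(model_out, x_t, t, predict_x0 = True):
    # sampling-side counterpart (the part the commit message says still needs fixing):
    # with predict_x0 the network output already *is* the x_start estimate, whereas a
    # noise-predicting network recovers x_start = (x_t - sqrt(1 - a_bar) * eps) / sqrt(a_bar)
    if predict_x0:
        return model_out
    a_bar = alphas_cumprod[t].view(-1, 1)
    return (x_t - (1. - a_bar).sqrt() * model_out) / a_bar.sqrt()

class ToyPrior(torch.nn.Module):
    # stand-in for the real prior network; ignores text conditioning for brevity
    def __init__(self, dim = 512):
        super().__init__()
        self.net = torch.nn.Linear(dim + 1, dim)

    def forward(self, x, t):
        t_emb = t.float().view(-1, 1) / TIMESTEPS
        return self.net(torch.cat((x, t_emb), dim = -1))

model = ToyPrior()
image_embed = torch.randn(4, 512)
t = torch.randint(0, TIMESTEPS, (4,))
p_losses(model, image_embed, t, predict_x0 = True).backward()

Under these assumptions, the pending sampling fix would amount to routing the
network output through something like model_output_to_x_start before computing
the posterior mean, instead of always treating the output as predicted noise.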