allow for predicting image embedding directly during diffusion training. need to fix sampling still

Author: Phil Wang
Date:   2022-04-13 10:29:23 -07:00
parent  9f1fe6c7ae
commit  1bf071af78


@@ -277,7 +277,8 @@ class DiffusionPrior(nn.Module):
         clip,
         timesteps = 1000,
         cond_prob_drop = 0.2,
-        loss_type = 'l1'
+        loss_type = 'l1',
+        predict_x0 = True
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
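
Enabling the new behavior is then just a constructor flag. A hypothetical usage sketch follows; the `net` argument and its value are assumptions based on the visible signature, not taken from this commit:

    diffusion_prior = DiffusionPrior(
        net = prior_network,    # denoising network, constructed elsewhere (assumed argument)
        clip = clip,            # a trained CLIP instance
        timesteps = 1000,
        cond_prob_drop = 0.2,
        loss_type = 'l1',
        predict_x0 = True       # regress the image embedding itself rather than the noise
    )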
@@ -289,6 +290,9 @@ class DiffusionPrior(nn.Module):
         self.image_size = clip.image_size
         self.cond_prob_drop = cond_prob_drop
 
+        self.predict_x0 = predict_x0
+        # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
+
         betas = cosine_beta_schedule(timesteps)
         alphas = 1. - betas
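
For reference, the two regression targets are interchangeable through the forward-diffusion equation x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps, so given x_t either one determines the other. A minimal sketch of that relationship, using standalone helpers that are not code from this repo:

    import torch

    def q_sample(image_embed, alphas_cumprod, t, noise):
        # forward process: x_t = sqrt(a_bar_t) * x_0 + sqrt(1 - a_bar_t) * eps
        a_bar = alphas_cumprod[t].view(-1, 1)
        return torch.sqrt(a_bar) * image_embed + torch.sqrt(1. - a_bar) * noise

    def x0_from_noise(x_t, alphas_cumprod, t, noise):
        # invert the forward process to recover the clean embedding from predicted noise
        a_bar = alphas_cumprod[t].view(-1, 1)
        return (x_t - torch.sqrt(1. - a_bar) * noise) / torch.sqrt(a_bar)

With predict_x0 = True the network regresses the clean image embedding directly; with predict_x0 = False it regresses the noise, and the embedding is recovered as above at sampling time.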
@@ -415,10 +419,12 @@ class DiffusionPrior(nn.Module):
             **text_cond
         )
 
+        to_predict = noise if not self.predict_x0 else image_embed
+
         if self.loss_type == 'l1':
-            loss = F.l1_loss(noise, x_recon)
+            loss = F.l1_loss(to_predict, x_recon)
         elif self.loss_type == 'l2':
-            loss = F.mse_loss(noise, x_recon)
+            loss = F.mse_loss(to_predict, x_recon)
         else:
             raise NotImplementedError()
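
Per the commit message, sampling has not yet been updated to respect the new flag. A hedged sketch of what the reverse step would need to do is below; predict_start_from_noise and q_posterior are assumed helpers in the style of common DDPM implementations, not necessarily this file's exact API:

    def p_mean_variance(self, x, t, text_cond):
        pred = self.net(x, t, **text_cond)

        if self.predict_x0:
            # the network output is already the denoised image embedding
            x_start = pred
        else:
            # the network predicts noise; convert it back to x_0 first (assumed helper)
            x_start = self.predict_start_from_noise(x, t = t, noise = pred)

        # posterior q(x_{t-1} | x_t, x_0) gives the mean / variance for the reverse step (assumed helper)
        return self.q_posterior(x_start = x_start, x_t = x, t = t)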