mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-21 02:34:19 +01:00)
allow for predicting image embedding directly during diffusion training. need to fix sampling still
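
For context, the change switches the training target of the diffusion prior between the usual noise (epsilon) prediction and direct x0 prediction of the image embedding. A minimal sketch of the objective under hypothetical standalone names (the real logic lives inside DiffusionPrior's forward pass, which also handles text conditioning):

```python
import torch
import torch.nn.functional as F

def q_sample(image_embed, t, alphas_cumprod, noise):
    # forward diffusion: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps
    abar = alphas_cumprod[t].view(-1, 1)
    return abar.sqrt() * image_embed + (1. - abar).sqrt() * noise

def prior_loss(net, image_embed, t, alphas_cumprod, predict_x0 = True):
    noise = torch.randn_like(image_embed)
    image_embed_noisy = q_sample(image_embed, t, alphas_cumprod, noise)
    x_recon = net(image_embed_noisy, t)  # text conditioning elided here
    # the core change: regress against x0 itself rather than the noise
    to_predict = noise if not predict_x0 else image_embed
    return F.l1_loss(to_predict, x_recon)
```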
@@ -277,7 +277,8 @@ class DiffusionPrior(nn.Module):
         clip,
         timesteps = 1000,
         cond_prob_drop = 0.2,
-        loss_type = 'l1'
+        loss_type = 'l1',
+        predict_x0 = True
     ):
         super().__init__()
         assert isinstance(clip, CLIP)
@@ -289,6 +290,9 @@ class DiffusionPrior(nn.Module):
         self.image_size = clip.image_size
         self.cond_prob_drop = cond_prob_drop
 
+        self.predict_x0 = predict_x0
+        # in the paper, they do not predict the noise, but predict x0 directly for the image embedding, claiming empirically better results. I'll just offer both.
+
         betas = cosine_beta_schedule(timesteps)
 
         alphas = 1. - betas
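
For reference, the cosine_beta_schedule the context lines call is conventionally the Nichol & Dhariwal cosine schedule; a sketch of what such a helper typically computes (the repo's exact implementation may differ):

```python
import math
import torch

def cosine_beta_schedule(timesteps, s = 0.008):
    # squared-cosine alphas_cumprod, converted to per-step betas
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1. - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return betas.clamp(0, 0.999)
```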
@@ -415,10 +419,12 @@ class DiffusionPrior(nn.Module):
             **text_cond
         )
 
+        to_predict = noise if not self.predict_x0 else image_embed
+
         if self.loss_type == 'l1':
-            loss = F.l1_loss(noise, x_recon)
+            loss = F.l1_loss(to_predict, x_recon)
         elif self.loss_type == 'l2':
-            loss = F.mse_loss(noise, x_recon)
+            loss = F.mse_loss(to_predict, x_recon)
         else:
             raise NotImplementedError()
 
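
The commit message notes sampling still needs fixing: once predict_x0 is set, the network output at sampling time is already an estimate of x0, so the usual noise-to-x0 conversion has to be bypassed. A hedged sketch of the missing branch; predict_start_from_noise is the conventional DDPM helper name and an assumption here, not this commit's verified code:

```python
def get_x_start(self, x_t, t, model_output):
    if self.predict_x0:
        # model was trained to output the clean image embedding (x0) directly
        return model_output
    # otherwise recover x0 from the predicted noise, the standard DDPM way:
    # x0 = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)
    return self.predict_start_from_noise(x_t, t = t, noise = model_output)
```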