mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Also offer l2norm clamping of the prediction in the diffusion prior during training (`training_clamp_l2norm`), applied only when using the predict-x0 (`predict_x_start`) objective
This commit is contained in:
@@ -805,6 +805,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
beta_schedule = "cosine",
|
beta_schedule = "cosine",
|
||||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||||
sampling_clamp_l2norm = False,
|
sampling_clamp_l2norm = False,
|
||||||
|
training_clamp_l2norm = False,
|
||||||
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
||||||
clip_adapter_overrides = dict()
|
clip_adapter_overrides = dict()
|
||||||
):
|
):
|
||||||
@@ -842,6 +843,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
# whether to force an l2norm, similar to clipping denoised, when sampling
|
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||||
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||||
|
self.training_clamp_l2norm = training_clamp_l2norm
|
||||||
|
|
||||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||||
pred = self.net(x, t, **text_cond)
|
pred = self.net(x, t, **text_cond)
|
||||||
@@ -894,6 +896,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
**text_cond
|
**text_cond
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.predict_x_start and self.training_clamp_l2norm:
|
||||||
|
pred = l2norm(pred) * self.image_embed_scale
|
||||||
|
|
||||||
target = noise if not self.predict_x_start else image_embed
|
target = noise if not self.predict_x_start else image_embed
|
||||||
|
|
||||||
loss = self.loss_fn(pred, target)
|
loss = self.loss_fn(pred, target)
|
||||||
|
|||||||
Reference in New Issue
Block a user