also offer l2norm clamping in diffusion prior during training, if one were using predict x0 objective

This commit is contained in:
Phil Wang
2022-05-06 10:05:14 -07:00
parent 09e9eaa5a6
commit 14e63a3f67
2 changed files with 6 additions and 1 deletions

View File

@@ -805,6 +805,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
    beta_schedule = "cosine",
    condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
    sampling_clamp_l2norm = False,
+   training_clamp_l2norm = False,
    image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
    clip_adapter_overrides = dict()
):
@@ -842,6 +843,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
    # whether to force an l2norm, similar to clipping denoised, when sampling
    self.sampling_clamp_l2norm = sampling_clamp_l2norm
+   self.training_clamp_l2norm = training_clamp_l2norm

def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
    pred = self.net(x, t, **text_cond)
@@ -894,6 +896,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
        **text_cond
    )

+   if self.predict_x_start and self.training_clamp_l2norm:
+       pred = l2norm(pred) * self.image_embed_scale

    target = noise if not self.predict_x_start else image_embed
    loss = self.loss_fn(pred, target)

View File

@@ -10,7 +10,7 @@ setup(
        'dream = dalle2_pytorch.cli:dream'
    ],
  },
- version = '0.1.1',
+ version = '0.1.2',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',