From 14e63a3f67674435a1a15b45e170c6a1146484d3 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 6 May 2022 10:05:14 -0700
Subject: [PATCH] also offer l2norm clamping in diffusion prior during
 training, if one were using predict x0 objective

---
 dalle2_pytorch/dalle2_pytorch.py | 5 +++++
 setup.py                         | 2 +-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index df3c8fb..7d78db6 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -805,6 +805,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         beta_schedule = "cosine",
         condition_on_text_encodings = True,  # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False,
+        training_clamp_l2norm = False,
         image_embed_scale = None,            # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
         clip_adapter_overrides = dict()
     ):
@@ -842,6 +843,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
         # whether to force an l2norm, similar to clipping denoised, when sampling

         self.sampling_clamp_l2norm = sampling_clamp_l2norm
+        self.training_clamp_l2norm = training_clamp_l2norm

     def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
         pred = self.net(x, t, **text_cond)
@@ -894,6 +896,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
             **text_cond
         )

+        if self.predict_x_start and self.training_clamp_l2norm:
+            pred = l2norm(pred) * self.image_embed_scale
+
         target = noise if not self.predict_x_start else image_embed

         loss = self.loss_fn(pred, target)
diff --git a/setup.py b/setup.py
index 4674f7d..7528174 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.1.1',
+  version = '0.1.2',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',