diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index bf1360d..bb1da34 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -922,11 +922,12 @@ class DiffusionPrior(nn.Module):
         loss_type = "l2",
         predict_x_start = True,
         beta_schedule = "cosine",
-        condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
-        sampling_clamp_l2norm = False,
+        condition_on_text_encodings = True,  # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
+        sampling_clamp_l2norm = False,       # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
+        sampling_final_clamp_l2norm = False, # whether to l2norm the final image embedding output (this is also done for images in ddpm)
         training_clamp_l2norm = False,
         init_image_embed_l2norm = False,
-        image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+        image_embed_scale = None,            # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
         clip_adapter_overrides = dict()
     ):
         super().__init__()
@@ -963,23 +964,32 @@ class DiffusionPrior(nn.Module):
         self.condition_on_text_encodings = condition_on_text_encodings
 
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
+
         self.predict_x_start = predict_x_start
 
         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
+
         self.image_embed_scale = default(image_embed_scale, self.image_embed_dim ** 0.5)
 
         # whether to force an l2norm, similar to clipping denoised, when sampling
+
         self.sampling_clamp_l2norm = sampling_clamp_l2norm
+        self.sampling_final_clamp_l2norm = sampling_final_clamp_l2norm
+
         self.training_clamp_l2norm = training_clamp_l2norm
         self.init_image_embed_l2norm = init_image_embed_l2norm
 
         # device tracker
+
         self.register_buffer('_dummy', torch.tensor([True]), persistent = False)
 
     @property
     def device(self):
         return self._dummy.device
 
+    def l2norm_clamp_embed(self, image_embed):
+        return l2norm(image_embed) * self.image_embed_scale
+
     def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
@@ -1020,6 +1030,9 @@ class DiffusionPrior(nn.Module):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
             image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
 
+        if self.sampling_final_clamp_l2norm and self.predict_x_start:
+            image_embed = self.l2norm_clamp_embed(image_embed)
+
         return image_embed
 
     @torch.no_grad()
@@ -1055,7 +1068,7 @@ class DiffusionPrior(nn.Module):
                 x_start.clamp_(-1., 1.)
 
             if self.predict_x_start and self.sampling_clamp_l2norm:
-                x_start = l2norm(x_start) * self.image_embed_scale
+                x_start = self.l2norm_clamp_embed(x_start)
 
             c1 = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
             c2 = ((1 - alpha_next) - torch.square(c1)).sqrt()
@@ -1065,6 +1078,9 @@ class DiffusionPrior(nn.Module):
                 c1 * noise + \
                 c2 * pred_noise
 
+        if self.predict_x_start and self.sampling_final_clamp_l2norm:
+            image_embed = self.l2norm_clamp_embed(image_embed)
+
         return image_embed
 
     @torch.no_grad()
@@ -1091,7 +1107,7 @@ class DiffusionPrior(nn.Module):
         )
 
         if self.predict_x_start and self.training_clamp_l2norm:
-            pred = l2norm(pred) * self.image_embed_scale
+            pred = self.l2norm_clamp_embed(pred)
 
         target = noise if not self.predict_x_start else image_embed
 
diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py
index cd5705a..bd5657d 100644
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.19.5'
+__version__ = '0.19.6'
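
For context, a minimal standalone sketch of the l2norm clamp these hunks factor into l2norm_clamp_embed, assuming l2norm is F.normalize over the last dimension (as elsewhere in the repo) and an illustrative image_embed_dim of 512; the real method lives on DiffusionPrior and uses the module's own image_embed_scale:

    import torch
    import torch.nn.functional as F

    def l2norm(t):
        return F.normalize(t, dim = -1)

    # image_embed_scale defaults to sqrt(image_embed_dim), per @crowsonkb's suggestion
    image_embed_dim = 512                      # illustrative only
    image_embed_scale = image_embed_dim ** 0.5

    def l2norm_clamp_embed(image_embed):
        # snap the embedding back onto the hypersphere of radius image_embed_scale,
        # analogous to the -1 to 1 clamp applied to images in usual DDPMs
        return l2norm(image_embed) * image_embed_scale

    # with sampling_final_clamp_l2norm = True (and predict_x_start), the sampled
    # embedding receives one last clamp after the final denoising step
    image_embed = torch.randn(4, image_embed_dim)  # stand-in for a sampled batch
    image_embed = l2norm_clamp_embed(image_embed)

The per-step clamp (sampling_clamp_l2norm) applies the same projection to each predicted x_start during denoising; the new sampling_final_clamp_l2norm flag additionally applies it once to the final output.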