diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 94f129f..f7c3b1c 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1279,9 +1279,12 @@ class DiffusionPrior(nn.Module): is_ddim = timesteps < self.noise_scheduler.num_timesteps if not is_ddim: - return self.p_sample_loop_ddpm(*args, **kwargs) + normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs) + else: + normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps) - return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps) + image_embed = normalized_image_embed / self.image_embed_scale + return image_embed def p_losses(self, image_embed, times, text_cond, noise = None): noise = default(noise, lambda: torch.randn_like(image_embed)) @@ -1350,8 +1353,6 @@ class DiffusionPrior(nn.Module): # retrieve original unscaled image embed - image_embeds /= self.image_embed_scale - text_embeds = text_cond['text_embed'] text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index d07785c..f3df7f0 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.6.4' +__version__ = '1.6.5'