diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index c9cf125..1f2079b 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -807,6 +807,7 @@ class DiffusionPrior(BaseGaussianDiffusion): condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training sampling_clamp_l2norm = False, training_clamp_l2norm = False, + init_image_embed_l2norm = False, image_embed_scale = None, # this is for scaling the l2-normed image embedding, so it is more suitable for gaussian diffusion, as outlined by Katherine (@crowsonkb) https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132 clip_adapter_overrides = dict() ): @@ -845,6 +846,7 @@ class DiffusionPrior(BaseGaussianDiffusion): # whether to force an l2norm, similar to clipping denoised, when sampling self.sampling_clamp_l2norm = sampling_clamp_l2norm self.training_clamp_l2norm = training_clamp_l2norm + self.init_image_embed_l2norm = init_image_embed_l2norm def p_mean_variance(self, x, t, text_cond, clip_denoised: bool): pred = self.net(x, t, **text_cond) @@ -879,11 +881,16 @@ class DiffusionPrior(BaseGaussianDiffusion): device = self.betas.device b = shape[0] - img = torch.randn(shape, device=device) + image_embed = torch.randn(shape, device=device) + + if self.init_image_embed_l2norm: + image_embed = l2norm(image_embed) * self.image_embed_scale for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): - img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond) - return img + times = torch.full((b,), i, device = device, dtype = torch.long) + image_embed = self.p_sample(image_embed, times, text_cond = text_cond) + + return image_embed def p_losses(self, image_embed, times, text_cond, noise = None): noise = default(noise, lambda: torch.randn_like(image_embed)) diff --git a/setup.py b/setup.py index 6d86262..410d7f5 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.1.6', + version = '0.1.7', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',