From db0642c4cdcaf4639136c1978c8124235d75a52e Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Wed, 18 May 2022 20:22:52 -0700 Subject: [PATCH] quick fix for @marunine --- dalle2_pytorch/dalle2_pytorch.py | 15 ++++++++++----- setup.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index abf1072..91f8f42 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1697,7 +1697,8 @@ class Decoder(BaseGaussianDiffusion): clip_adapter_overrides = dict(), learned_variance = True, vb_loss_weight = 0.001, - unconditional = False + unconditional = False, + auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader ): super().__init__( beta_schedule = beta_schedule, @@ -1806,6 +1807,10 @@ class Decoder(BaseGaussianDiffusion): self.clip_denoised = clip_denoised self.clip_x_start = clip_x_start + # normalize and unnormalize image functions + self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity + self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity + def get_unet(self, unet_number): assert 0 < unet_number <= len(self.unets) index = unet_number - 1 @@ -1877,7 +1882,7 @@ class Decoder(BaseGaussianDiffusion): img = torch.randn(shape, device = device) if not is_latent_diffusion: - lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img) + lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): img = self.p_sample( @@ -1894,7 +1899,7 @@ class Decoder(BaseGaussianDiffusion): clip_denoised = clip_denoised ) - unnormalize_img = unnormalize_zero_to_one(img) + unnormalize_img = self.unnormalize_img(img) return unnormalize_img def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False): @@ -1903,8 +1908,8 @@ class Decoder(BaseGaussianDiffusion): # normalize to [-1, 1] if not is_latent_diffusion: - x_start = normalize_neg_one_to_one(x_start) - lowres_cond_img = maybe(normalize_neg_one_to_one)(lowres_cond_img) + x_start = self.normalize_img(x_start) + lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) # get x_t diff --git a/setup.py b/setup.py index 59dc7ae..e3731a8 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.3.2', + version = '0.3.3', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',