From 46be8c32d3feb8b59351a67cbd73d80027fcbb4f Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Thu, 7 Jul 2022 09:41:49 -0700 Subject: [PATCH] fix a potential issue in the low resolution conditioner, when downsampling and then upsampling using resize right, thanks to @marunine --- README.md | 1 + dalle2_pytorch/dalle2_pytorch.py | 21 +++++++++++++++++---- dalle2_pytorch/version.py | 2 +- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index f205024..aabb062 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ This library would not have gotten to this working state without the help of - Kumar for working on the initial diffusion training script - Romain for the pull request reviews and project management - He Cao and xiankgx for the Q&A and for identifying of critical bugs +- Marunine for identifying issues with resizing of the low resolution conditioner, when training the upsampler, in addition to various other bug fixes - Katherine for her advice - Stability AI for the generous sponsorship - 🤗 Huggingface and in particular Sylvain for the Accelerate library diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 8598fa4..738a09e 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -125,14 +125,19 @@ def log(t, eps = 1e-12): def l2norm(t): return F.normalize(t, dim = -1) -def resize_image_to(image, target_image_size): +def resize_image_to(image, target_image_size, clamp_range = None): orig_image_size = image.shape[-1] if orig_image_size == target_image_size: return image scale_factors = target_image_size / orig_image_size - return resize(image, scale_factors = scale_factors) + out = resize(image, scale_factors = scale_factors) + + if exists(clamp_range): + out = out.clamp(*clamp_range) + + return out # image normalization functions # ddpms expect images to be in the range of -1 to 1 @@ -1778,9 +1783,12 @@ class LowresConditioner(nn.Module): downsample_first = True, blur_sigma = 0.6, blur_kernel_size = 3, + input_image_range = None ): super().__init__() self.downsample_first = downsample_first + self.input_image_range = input_image_range + self.blur_sigma = blur_sigma self.blur_kernel_size = blur_kernel_size @@ -1794,7 +1802,7 @@ class LowresConditioner(nn.Module): blur_kernel_size = None ): if self.training and self.downsample_first and exists(downsample_image_size): - cond_fmap = resize_image_to(cond_fmap, downsample_image_size) + cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range) if self.training: # when training, blur the low resolution conditional image @@ -1814,7 +1822,7 @@ class LowresConditioner(nn.Module): cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2)) - cond_fmap = resize_image_to(cond_fmap, target_image_size) + cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range) return cond_fmap @@ -1968,6 +1976,10 @@ class Decoder(nn.Module): self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) + # input image range + + self.input_image_range = (-1. if not auto_normalize_img else 0., 1.) + # cascading ddpm related stuff lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) @@ -1977,6 +1989,7 @@ class Decoder(nn.Module): downsample_first = lowres_downsample_first, blur_sigma = blur_sigma, blur_kernel_size = blur_kernel_size, + input_image_range = self.input_image_range ) # classifier free guidance diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 17136cb..b086370 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.15' +__version__ = '0.16.16'