From 8c823affff30e3a522ebbee6d5dc6c76d2490177 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 8 Jul 2022 11:44:43 -0700 Subject: [PATCH] allow for control over use of nearest interp method of downsampling low res conditioning, in addition to being able to turn it off --- dalle2_pytorch/dalle2_pytorch.py | 22 ++++++++++++++++++---- dalle2_pytorch/version.py | 2 +- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 738a09e..7576698 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -125,14 +125,23 @@ def log(t, eps = 1e-12): def l2norm(t): return F.normalize(t, dim = -1) -def resize_image_to(image, target_image_size, clamp_range = None): +def resize_image_to( + image, + target_image_size, + clamp_range = None, + nearest = False, + **kwargs +): orig_image_size = image.shape[-1] if orig_image_size == target_image_size: return image - scale_factors = target_image_size / orig_image_size - out = resize(image, scale_factors = scale_factors) + if not nearest: + scale_factors = target_image_size / orig_image_size + out = resize(image, scale_factors = scale_factors, **kwargs) + else: + out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False) if exists(clamp_range): out = out.clamp(*clamp_range) @@ -1781,12 +1790,15 @@ class LowresConditioner(nn.Module): def __init__( self, downsample_first = True, + downsample_mode_nearest = False, blur_sigma = 0.6, blur_kernel_size = 3, input_image_range = None ): super().__init__() self.downsample_first = downsample_first + self.downsample_mode_nearest = downsample_mode_nearest + self.input_image_range = input_image_range self.blur_sigma = blur_sigma @@ -1802,7 +1814,7 @@ class LowresConditioner(nn.Module): blur_kernel_size = None ): if self.training and self.downsample_first and exists(downsample_image_size): - cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range) + cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest) if self.training: # when training, blur the low resolution conditional image @@ -1845,6 +1857,7 @@ class Decoder(nn.Module): image_sizes = None, # for cascading ddpm, image size at each stage random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur + lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution blur_sigma = 0.6, # cascading ddpm - blur sigma blur_kernel_size = 3, # cascading ddpm - blur kernel size clip_denoised = True, @@ -1987,6 +2000,7 @@ class Decoder(nn.Module): self.to_lowres_cond = LowresConditioner( downsample_first = lowres_downsample_first, + downsample_mode_nearest = lowres_downsample_mode_nearest, blur_sigma = blur_sigma, blur_kernel_size = blur_kernel_size, input_image_range = self.input_image_range diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index ada3c31..8e158f5 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.16.17' +__version__ = '0.16.18'