diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 1a9b3e2..769e6c5 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -146,7 +146,7 @@ def resize_image_to( scale_factors = target_image_size / orig_image_size out = resize(image, scale_factors = scale_factors, **kwargs) else: - out = F.interpolate(image, target_image_size, mode = 'nearest', align_corners = False) + out = F.interpolate(image, target_image_size, mode = 'nearest') if exists(clamp_range): out = out.clamp(*clamp_range) @@ -1957,7 +1957,6 @@ class LowresConditioner(nn.Module): def __init__( self, downsample_first = True, - downsample_mode_nearest = False, blur_prob = 0.5, blur_sigma = 0.6, blur_kernel_size = 3, @@ -1965,8 +1964,6 @@ class LowresConditioner(nn.Module): ): super().__init__() self.downsample_first = downsample_first - self.downsample_mode_nearest = downsample_mode_nearest - self.input_image_range = input_image_range self.blur_prob = blur_prob @@ -1983,7 +1980,7 @@ class LowresConditioner(nn.Module): blur_kernel_size = None ): if self.downsample_first and exists(downsample_image_size): - cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest) + cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = True) # blur is only applied 50% of the time # section 3.1 in https://arxiv.org/abs/2106.15282 @@ -2010,7 +2007,7 @@ class LowresConditioner(nn.Module): cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2)) - cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range) + cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range, nearest = True) return cond_fmap class Decoder(nn.Module): @@ -2033,7 +2030,6 @@ class Decoder(nn.Module): image_sizes = None, # for cascading ddpm, image size at each stage random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur - lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time blur_sigma = 0.6, # cascading ddpm - blur sigma blur_kernel_size = 3, # cascading ddpm - blur kernel size @@ -2183,11 +2179,8 @@ class Decoder(nn.Module): lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True' - self.lowres_downsample_mode_nearest = lowres_downsample_mode_nearest - self.to_lowres_cond = LowresConditioner( downsample_first = lowres_downsample_first, - downsample_mode_nearest = lowres_downsample_mode_nearest, blur_prob = blur_prob, blur_sigma = blur_sigma, blur_kernel_size = blur_kernel_size, @@ -2510,7 +2503,7 @@ class Decoder(nn.Module): shape = (batch_size, channel, image_size, image_size) if unet.lowres_cond: - lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = self.lowres_downsample_mode_nearest) + lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = True) is_latent_diffusion = isinstance(vae, VQGanVAE) image_size = vae.get_encoded_fmap_size(image_size) @@ -2580,7 +2573,7 @@ class Decoder(nn.Module): assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented' lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None - image = resize_image_to(image, target_image_size) + image = resize_image_to(image, target_image_size, nearest = True) if exists(random_crop_size): aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index d5d3ec6..ce97d1d 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.23.9' +__version__ = '0.23.10'