diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 218d931..f35e9b3 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1946,6 +1946,7 @@ class LowresConditioner(nn.Module): self, downsample_first = True, downsample_mode_nearest = False, + blur_prob = 0.5, blur_sigma = 0.6, blur_kernel_size = 3, input_image_range = None @@ -1956,6 +1957,7 @@ class LowresConditioner(nn.Module): self.input_image_range = input_image_range + self.blur_prob = blur_prob self.blur_sigma = blur_sigma self.blur_kernel_size = blur_kernel_size @@ -1968,20 +1970,27 @@ class LowresConditioner(nn.Module): blur_sigma = None, blur_kernel_size = None ): - if self.training and self.downsample_first and exists(downsample_image_size): + if self.downsample_first and exists(downsample_image_size): cond_fmap = resize_image_to(cond_fmap, downsample_image_size, clamp_range = self.input_image_range, nearest = self.downsample_mode_nearest) - if self.training: + # blur is only applied 50% of the time + # section 3.1 in https://arxiv.org/abs/2106.15282 + + if random.random() < self.blur_prob: + # when training, blur the low resolution conditional image + blur_sigma = default(blur_sigma, self.blur_sigma) blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size) # allow for drawing a random sigma between lo and hi float values + if isinstance(blur_sigma, tuple): blur_sigma = tuple(map(float, blur_sigma)) blur_sigma = random.uniform(*blur_sigma) # allow for drawing a random kernel size between lo and hi int values + if isinstance(blur_kernel_size, tuple): blur_kernel_size = tuple(map(int, blur_kernel_size)) kernel_size_lo, kernel_size_hi = blur_kernel_size @@ -1990,7 +1999,6 @@ class LowresConditioner(nn.Module): cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2)) cond_fmap = resize_image_to(cond_fmap, target_image_size, clamp_range = self.input_image_range) - return cond_fmap class Decoder(nn.Module): @@ -2014,6 +2022,7 @@ class Decoder(nn.Module): random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur lowres_downsample_mode_nearest = False, # cascading ddpm - whether to use nearest mode downsampling for lower resolution + blur_prob = 0.5, # cascading ddpm - when training, the gaussian blur is only applied 50% of the time blur_sigma = 0.6, # cascading ddpm - blur sigma blur_kernel_size = 3, # cascading ddpm - blur kernel size clip_denoised = True, @@ -2162,9 +2171,12 @@ class Decoder(nn.Module): lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True' + self.lowres_downsample_mode_nearest = lowres_downsample_mode_nearest + self.to_lowres_cond = LowresConditioner( downsample_first = lowres_downsample_first, downsample_mode_nearest = lowres_downsample_mode_nearest, + blur_prob = blur_prob, blur_sigma = blur_sigma, blur_kernel_size = blur_kernel_size, input_image_range = self.input_image_range @@ -2483,7 +2495,7 @@ class Decoder(nn.Module): shape = (batch_size, channel, image_size, image_size) if unet.lowres_cond: - lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size) + lowres_cond_img = resize_image_to(img, target_image_size = image_size, clamp_range = self.input_image_range, nearest = self.lowres_downsample_mode_nearest) is_latent_diffusion = isinstance(vae, VQGanVAE) image_size = vae.get_encoded_fmap_size(image_size) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index 4946c77..a123ffa 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '0.23.4' +__version__ = '0.23.5'