From aec5575d0911552ae705147ebcecbac383093fc5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 4 May 2022 19:26:45 -0700
Subject: [PATCH] take a bet on resize right, given Katherine is using it

---
 dalle2_pytorch/dalle2_pytorch.py | 24 ++++++++++--------------
 setup.py                         |  3 ++-
 2 files changed, 12 insertions(+), 15 deletions(-)

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 2d99961..35df9d8 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -21,6 +21,8 @@ import kornia.augmentation as K
 from dalle2_pytorch.tokenizer import tokenizer
 from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
 
+from resize_right import resize
+
 # use x-clip
 
 from x_clip import CLIP
@@ -86,14 +88,14 @@ def freeze_model_and_make_eval_(model):
 def l2norm(t):
     return F.normalize(t, dim = -1)
 
-def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
-    shape = cast_tuple(image_size, 2)
-    orig_image_size = t.shape[-2:]
+def resize_image_to(image, target_image_size):
+    orig_image_size = image.shape[-1]
 
-    if orig_image_size == shape:
-        return t
+    if orig_image_size == target_image_size:
+        return image
 
-    return F.interpolate(t, size = shape, mode = mode, align_corners = False)
+    scale_factors = target_image_size / orig_image_size
+    return resize(image, scale_factors = scale_factors)
 
 # image normalization functions
 # ddpms expect images to be in the range of -1 to 1
@@ -1477,13 +1479,11 @@ class Unet(nn.Module):
 class LowresConditioner(nn.Module):
     def __init__(
         self,
-        cond_upsample_mode = 'bilinear',
         downsample_first = True,
         blur_sigma = 0.1,
         blur_kernel_size = 3,
     ):
         super().__init__()
-        self.cond_upsample_mode = cond_upsample_mode
         self.downsample_first = downsample_first
         self.blur_sigma = blur_sigma
         self.blur_kernel_size = blur_kernel_size
@@ -1497,10 +1497,8 @@ class LowresConditioner(nn.Module):
         blur_sigma = None,
         blur_kernel_size = None
     ):
-        target_image_size = cast_tuple(target_image_size, 2)
-
         if self.training and self.downsample_first and exists(downsample_image_size):
-            cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
+            cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
 
         if self.training:
             # when training, blur the low resolution conditional image
@@ -1508,7 +1506,7 @@ class LowresConditioner(nn.Module):
             blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
             cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
 
-        cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
+        cond_fmap = resize_image_to(cond_fmap, target_image_size)
 
         return cond_fmap
 
@@ -1528,7 +1526,6 @@ class Decoder(BaseGaussianDiffusion):
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None,                         # for cascading ddpm, image size at each stage
         random_crop_sizes = None,                   # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
-        lowres_cond_upsample_mode = 'bilinear',     # cascading ddpm - low resolution upsample mode
         lowres_downsample_first = True,             # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
         blur_sigma = 0.1,                           # cascading ddpm - blur sigma
         blur_kernel_size = 3,                       # cascading ddpm - blur kernel size
@@ -1604,7 +1601,6 @@ class Decoder(BaseGaussianDiffusion):
         assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
 
         self.to_lowres_cond = LowresConditioner(
-            cond_upsample_mode = lowres_cond_upsample_mode,
             downsample_first = lowres_downsample_first,
             blur_sigma = blur_sigma,
             blur_kernel_size = blur_kernel_size,
diff --git a/setup.py b/setup.py
index 5953e7a..dc5c744 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.99',
+  version = '0.0.100',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
@@ -29,6 +29,7 @@ setup(
     'embedding-reader',
     'kornia>=0.5.4',
     'pillow',
+    'resize-right>=0.0.2',
     'torch>=1.10',
     'torchvision',
     'tqdm',
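
For reference, a minimal standalone sketch of how the patched resize_image_to helper behaves on top of resize-right; the batch shape and the 64 -> 256 upsample below are illustrative assumptions, not taken from the patch.

import torch
from resize_right import resize

def resize_image_to(image, target_image_size):
    # mirrors the patched helper: compare the target against the last (square) spatial dim
    orig_image_size = image.shape[-1]

    if orig_image_size == target_image_size:
        return image

    # resize_right applies a scalar scale factor to the trailing spatial dims
    scale_factors = target_image_size / orig_image_size
    return resize(image, scale_factors = scale_factors)

# hypothetical usage: upsample a batch of 64x64 conditioning images to 256x256
images = torch.randn(4, 3, 64, 64)
upsampled = resize_image_to(images, 256)
print(upsampled.shape)  # expected: torch.Size([4, 3, 256, 256])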