take a bet on resize right, given Katherine is using it

This commit is contained in:
Phil Wang
2022-05-04 19:26:45 -07:00
parent 9773f10d6c
commit aec5575d09
2 changed files with 12 additions and 15 deletions

View File

@@ -21,6 +21,8 @@ import kornia.augmentation as K
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
from resize_right import resize
# use x-clip
from x_clip import CLIP
@@ -86,14 +88,14 @@ def freeze_model_and_make_eval_(model):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
shape = cast_tuple(image_size, 2)
orig_image_size = t.shape[-2:]
def resize_image_to(image, target_image_size):
orig_image_size = image.shape[-1]
if orig_image_size == shape:
return t
if orig_image_size == target_image_size:
return image
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
scale_factors = target_image_size / orig_image_size
return resize(image, scale_factors = scale_factors)
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
@@ -1477,13 +1479,11 @@ class Unet(nn.Module):
class LowresConditioner(nn.Module):
def __init__(
self,
cond_upsample_mode = 'bilinear',
downsample_first = True,
blur_sigma = 0.1,
blur_kernel_size = 3,
):
super().__init__()
self.cond_upsample_mode = cond_upsample_mode
self.downsample_first = downsample_first
self.blur_sigma = blur_sigma
self.blur_kernel_size = blur_kernel_size
@@ -1497,10 +1497,8 @@ class LowresConditioner(nn.Module):
blur_sigma = None,
blur_kernel_size = None
):
target_image_size = cast_tuple(target_image_size, 2)
if self.training and self.downsample_first and exists(downsample_image_size):
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
if self.training:
# when training, blur the low resolution conditional image
@@ -1508,7 +1506,7 @@ class LowresConditioner(nn.Module):
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
cond_fmap = resize_image_to(cond_fmap, target_image_size)
return cond_fmap
@@ -1528,7 +1526,6 @@ class Decoder(BaseGaussianDiffusion):
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
blur_kernel_size = 3, # cascading ddpm - blur kernel size
@@ -1604,7 +1601,6 @@ class Decoder(BaseGaussianDiffusion):
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.to_lowres_cond = LowresConditioner(
cond_upsample_mode = lowres_cond_upsample_mode,
downsample_first = lowres_downsample_first,
blur_sigma = blur_sigma,
blur_kernel_size = blur_kernel_size,