mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
take a bet on resize right, given Katherine is using it
This commit is contained in:
@@ -21,6 +21,8 @@ import kornia.augmentation as K
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
|
||||
from resize_right import resize
|
||||
|
||||
# use x-clip
|
||||
|
||||
from x_clip import CLIP
|
||||
@@ -86,14 +88,14 @@ def freeze_model_and_make_eval_(model):
|
||||
def l2norm(t):
|
||||
return F.normalize(t, dim = -1)
|
||||
|
||||
def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
|
||||
shape = cast_tuple(image_size, 2)
|
||||
orig_image_size = t.shape[-2:]
|
||||
def resize_image_to(image, target_image_size):
|
||||
orig_image_size = image.shape[-1]
|
||||
|
||||
if orig_image_size == shape:
|
||||
return t
|
||||
if orig_image_size == target_image_size:
|
||||
return image
|
||||
|
||||
return F.interpolate(t, size = shape, mode = mode, align_corners = False)
|
||||
scale_factors = target_image_size / orig_image_size
|
||||
return resize(image, scale_factors = scale_factors)
|
||||
|
||||
# image normalization functions
|
||||
# ddpms expect images to be in the range of -1 to 1
|
||||
@@ -1477,13 +1479,11 @@ class Unet(nn.Module):
|
||||
class LowresConditioner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cond_upsample_mode = 'bilinear',
|
||||
downsample_first = True,
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
):
|
||||
super().__init__()
|
||||
self.cond_upsample_mode = cond_upsample_mode
|
||||
self.downsample_first = downsample_first
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_kernel_size = blur_kernel_size
|
||||
@@ -1497,10 +1497,8 @@ class LowresConditioner(nn.Module):
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
target_image_size = cast_tuple(target_image_size, 2)
|
||||
|
||||
if self.training and self.downsample_first and exists(downsample_image_size):
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size, mode = self.cond_upsample_mode)
|
||||
cond_fmap = resize_image_to(cond_fmap, downsample_image_size)
|
||||
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
@@ -1508,7 +1506,7 @@ class LowresConditioner(nn.Module):
|
||||
blur_kernel_size = default(blur_kernel_size, self.blur_kernel_size)
|
||||
cond_fmap = gaussian_blur2d(cond_fmap, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size, mode = self.cond_upsample_mode)
|
||||
cond_fmap = resize_image_to(cond_fmap, target_image_size)
|
||||
|
||||
return cond_fmap
|
||||
|
||||
@@ -1528,7 +1526,6 @@ class Decoder(BaseGaussianDiffusion):
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||
@@ -1604,7 +1601,6 @@ class Decoder(BaseGaussianDiffusion):
|
||||
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
|
||||
|
||||
self.to_lowres_cond = LowresConditioner(
|
||||
cond_upsample_mode = lowres_cond_upsample_mode,
|
||||
downsample_first = lowres_downsample_first,
|
||||
blur_sigma = blur_sigma,
|
||||
blur_kernel_size = blur_kernel_size,
|
||||
|
||||
3
setup.py
3
setup.py
@@ -10,7 +10,7 @@ setup(
|
||||
'dream = dalle2_pytorch.cli:dream'
|
||||
],
|
||||
},
|
||||
version = '0.0.99',
|
||||
version = '0.0.100',
|
||||
license='MIT',
|
||||
description = 'DALL-E 2',
|
||||
author = 'Phil Wang',
|
||||
@@ -29,6 +29,7 @@ setup(
|
||||
'embedding-reader',
|
||||
'kornia>=0.5.4',
|
||||
'pillow',
|
||||
'resize-right>=0.0.2',
|
||||
'torch>=1.10',
|
||||
'torchvision',
|
||||
'tqdm',
|
||||
|
||||
Reference in New Issue
Block a user