mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
allow for last unet in the cascade to be trained on crops, if it is convolution-only
This commit is contained in:
@@ -16,6 +16,7 @@ from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
from kornia.filters import gaussian_blur2d
|
||||
import kornia.augmentation as K
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
@@ -1526,6 +1527,7 @@ class Decoder(BaseGaussianDiffusion):
|
||||
predict_x_start = False,
|
||||
predict_x_start_for_latent_diffusion = False,
|
||||
image_sizes = None, # for cascading ddpm, image size at each stage
|
||||
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
||||
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
|
||||
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
|
||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||
@@ -1588,6 +1590,10 @@ class Decoder(BaseGaussianDiffusion):
|
||||
self.image_sizes = image_sizes
|
||||
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
|
||||
|
||||
# random crop sizes (for super-resoluting unets at the end of cascade?)
|
||||
|
||||
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
|
||||
|
||||
# predict x0 config
|
||||
|
||||
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
||||
@@ -1777,10 +1783,10 @@ class Decoder(BaseGaussianDiffusion):
|
||||
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
vae = self.vaes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
|
||||
vae = self.vaes[unet_index]
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
predict_x_start = self.predict_x_start[unet_index]
|
||||
random_crop_size = self.random_crop_sizes[unet_index]
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
check_shape(image, 'b c h w', c = self.channels)
|
||||
@@ -1801,6 +1807,14 @@ class Decoder(BaseGaussianDiffusion):
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
if exists(random_crop_size):
|
||||
aug = K.RandomCrop((random_crop_size, random_crop_size))
|
||||
|
||||
# make sure low res conditioner and image both get augmented the same way
|
||||
# detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
|
||||
image = aug(image)
|
||||
lowres_cond_img = aug(lowres_cond_img, params = aug._params)
|
||||
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
|
||||
Reference in New Issue
Block a user