diff --git a/README.md b/README.md index 23f3426..253cc8f 100644 --- a/README.md +++ b/README.md @@ -822,6 +822,7 @@ Once built, images will be saved to the same directory the command is invoked - [x] bring in tools to train vqgan-vae - [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet) - [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation) +- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias) - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo) - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs @@ -833,7 +834,6 @@ Once built, images will be saved to the same directory the command is invoked - [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824 - [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove - [ ] use an experimental tracker agnostic setup, as done here -- [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias) - [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2 - [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783 - [ ] make sure resnet | convnext block hyperparameters can be configurable across unet depth (groups and expansion factor) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index dc35304..a9c853f 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -16,6 +16,7 @@ from einops_exts import rearrange_many, repeat_many, check_shape from einops_exts.torch import EinopsToAndFrom from kornia.filters import gaussian_blur2d +import kornia.augmentation as K from dalle2_pytorch.tokenizer import tokenizer from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE @@ -1526,6 +1527,7 @@ class Decoder(BaseGaussianDiffusion): predict_x_start = False, predict_x_start_for_latent_diffusion = False, image_sizes = None, # for cascading ddpm, image size at each stage + random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops) lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur blur_sigma = 0.1, # cascading ddpm - blur sigma @@ -1588,6 +1590,10 @@ class Decoder(BaseGaussianDiffusion): self.image_sizes = image_sizes self.sample_channels = cast_tuple(self.channels, len(image_sizes)) + # random crop sizes (for super-resoluting unets at the end of cascade?) + + self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes)) + # predict x0 config self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes)) @@ -1777,10 +1783,10 @@ class Decoder(BaseGaussianDiffusion): unet = self.get_unet(unet_number) - target_image_size = self.image_sizes[unet_index] - vae = self.vaes[unet_index] - predict_x_start = self.predict_x_start[unet_index] - + vae = self.vaes[unet_index] + target_image_size = self.image_sizes[unet_index] + predict_x_start = self.predict_x_start[unet_index] + random_crop_size = self.random_crop_sizes[unet_index] b, c, h, w, device, = *image.shape, image.device check_shape(image, 'b c h w', c = self.channels) @@ -1801,6 +1807,14 @@ class Decoder(BaseGaussianDiffusion): lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None image = resize_image_to(image, target_image_size) + if exists(random_crop_size): + aug = K.RandomCrop((random_crop_size, random_crop_size)) + + # make sure low res conditioner and image both get augmented the same way + # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop + image = aug(image) + lowres_cond_img = aug(lowres_cond_img, params = aug._params) + vae.eval() with torch.no_grad(): image = vae.encode(image) diff --git a/setup.py b/setup.py index 5ee48dc..c8bb013 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ setup( 'dream = dalle2_pytorch.cli:dream' ], }, - version = '0.0.96', + version = '0.0.97', license='MIT', description = 'DALL-E 2', author = 'Phil Wang',