allow for last unet in the cascade to be trained on crops, if it is convolution-only

This commit is contained in:
Phil Wang
2022-05-04 11:48:41 -07:00
parent 74103fd8d6
commit 97b751209f
3 changed files with 20 additions and 6 deletions

View File

@@ -822,6 +822,7 @@ Once built, images will be saved to the same directory the command is invoked
- [x] bring in tools to train vqgan-vae
- [x] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
- [x] make sure DDPMs can be run with traditional resnet blocks (but leave convnext as an option for experimentation)
- [x] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet (test out unet² in ddpm repo)
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
@@ -833,7 +834,6 @@ Once built, images will be saved to the same directory the command is invoked
- [ ] figure out if possible to augment with external memory, as described in https://arxiv.org/abs/2204.11824
- [ ] test out grid attention in cascading ddpm locally, decide whether to keep or remove
- [ ] use an experimental tracker agnostic setup, as done <a href="https://github.com/lucidrains/tf-bind-transformer#simple-trainer-class-for-fine-tuning">here</a>
- [ ] make sure for the latter unets in the cascade, one can train on crops for learning super resolution (constrain the unet to be only convolutions in that case, or allow conv-like attention with rel pos bias)
- [ ] interface out the vqgan-vae so a pretrained one can be pulled off the shelf to validate latent diffusion + DALL-E2
- [ ] make sure FILIP works with DALL-E2 from x-clip https://arxiv.org/abs/2111.07783
- [ ] make sure resnet | convnext block hyperparameters can be configurable across unet depth (groups and expansion factor)

View File

@@ -16,6 +16,7 @@ from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import gaussian_blur2d
import kornia.augmentation as K
from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
@@ -1526,6 +1527,7 @@ class Decoder(BaseGaussianDiffusion):
predict_x_start = False,
predict_x_start_for_latent_diffusion = False,
image_sizes = None, # for cascading ddpm, image size at each stage
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
blur_sigma = 0.1, # cascading ddpm - blur sigma
@@ -1588,6 +1590,10 @@ class Decoder(BaseGaussianDiffusion):
self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# random crop sizes (for super-resoluting unets at the end of cascade?)
self.random_crop_sizes = cast_tuple(random_crop_sizes, len(image_sizes))
# predict x0 config
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
@@ -1777,10 +1783,10 @@ class Decoder(BaseGaussianDiffusion):
unet = self.get_unet(unet_number)
target_image_size = self.image_sizes[unet_index]
vae = self.vaes[unet_index]
target_image_size = self.image_sizes[unet_index]
predict_x_start = self.predict_x_start[unet_index]
random_crop_size = self.random_crop_sizes[unet_index]
b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels)
@@ -1801,6 +1807,14 @@ class Decoder(BaseGaussianDiffusion):
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
image = resize_image_to(image, target_image_size)
if exists(random_crop_size):
aug = K.RandomCrop((random_crop_size, random_crop_size))
# make sure low res conditioner and image both get augmented the same way
# detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
image = aug(image)
lowres_cond_img = aug(lowres_cond_img, params = aug._params)
vae.eval()
with torch.no_grad():
image = vae.encode(image)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.96',
version = '0.0.97',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',