diff --git a/README.md b/README.md
index b2fc789..b4344e6 100644
--- a/README.md
+++ b/README.md
@@ -10,8 +10,6 @@ The main novelty seems to be an extra layer of indirection with the prior networ
 This model is SOTA for text-to-image for now.
 
-It may also explore an extension of using latent diffusion in the decoder from Rombach et al.
-
 Please join Join us on Discord if you are interested in helping out with the replication
 
 There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks.
 Placeholder repository. I will also eventually extend this to text to video, once the repository is in a good place.
@@ -385,6 +383,117 @@ You can also train the decoder on images of greater than the size (say 512x512)
 
 For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
 
+## Experimental
+
+### DALL-E2 with Latent Diffusion
+
+This repository takes the next step and offers DALL-E2 combined with latent diffusion, from Rombach et al.
+
+You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to as many of the unets as you wish.
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
+
+# trained clip from step 1
+
+clip = CLIP(
+    dim_text = 512,
+    dim_image = 512,
+    dim_latent = 512,
+    num_text_tokens = 49408,
+    text_enc_depth = 1,
+    text_seq_len = 256,
+    text_heads = 8,
+    visual_enc_depth = 1,
+    visual_image_size = 256,
+    visual_patch_size = 32,
+    visual_heads = 8
+)
+
+# 3 unets for the decoder (a la cascading DDPM)
+
+# the first two unets are doing latent diffusion
+
+vae1 = VQGanVAE(
+    dim = 32,
+    image_size = 256,
+    layers = 3,
+    layer_mults = (1, 2, 4)
+)
+
+vae2 = VQGanVAE(
+    dim = 32,
+    image_size = 512,
+    layers = 3,
+    layer_mults = (1, 2, 4)
+)
+
+unet1 = Unet(
+    dim = 32,
+    image_embed_dim = 512,
+    cond_dim = 128,
+    channels = 3,
+    sparse_attn = True,
+    sparse_attn_window = 2,
+    dim_mults = (1, 2, 4, 8)
+)
+
+unet2 = Unet(
+    dim = 32,
+    image_embed_dim = 512,
+    channels = 3,
+    dim_mults = (1, 2, 4, 8, 16),
+    cond_on_image_embeds = True,
+    cond_on_text_encodings = False
+)
+
+unet3 = Unet(
+    dim = 32,
+    image_embed_dim = 512,
+    channels = 3,
+    dim_mults = (1, 2, 4, 8, 16),
+    cond_on_image_embeds = True,
+    cond_on_text_encodings = False,
+    attend_at_middle = False
+)
+
+# decoder, which contains the unet(s) and clip
+
+decoder = Decoder(
+    clip = clip,
+    vae = (vae1, vae2),             # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
+    unet = (unet1, unet2, unet3),   # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
+    image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
+    timesteps = 100,
+    cond_drop_prob = 0.2
+).cuda()
+
+# mock images (get a lot of this)
+
+images = torch.randn(1, 3, 1024, 1024).cuda()
+
+# feed images into decoder, specifying which unet you want to train
+# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
+
+with decoder.one_unet_in_gpu(1):
+    loss = decoder(images, unet_number = 1)
+    loss.backward()
+
+with decoder.one_unet_in_gpu(2):
+    loss = decoder(images, unet_number = 2)
+    loss.backward()
+
+# do the above for many steps for all the unets
+
+# then it will learn to generate images based on the CLIP image embeddings
+
+# chaining the unets from lowest resolution to highest resolution (thus cascading)
+
+mock_image_embed = torch.randn(1, 512).cuda()
+images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
+```
+
 ## CLI Usage (work in progress)
 
 ```bash
@@ -412,11 +521,13 @@ Offer training wrappers
 - [x] add efficient attention in unet
 - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
 - [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
-- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
+- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
+- [ ] spend one day cleaning up tech debt in decoder
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] train on a toy task, offer in colab
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
+- [ ] bring in tools to train vqgan-vae
 
 ## Citations
 
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 41cbab4..2b27df3 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -48,6 +48,12 @@ def is_list_str(x):
         return False
     return all([type(el) == str for el in x])
 
+def pad_tuple_to_length(t, length):
+    remain_length = length - len(t)
+    if remain_length <= 0:
+        return t
+    return (*t, *((None,) * remain_length))
+
 # for controlling freezing of CLIP
 
 def set_module_requires_grad_(module, requires_grad):
@@ -540,12 +546,14 @@ class DiffusionPrior(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    @torch.no_grad()
     def get_image_embed(self, image):
         image_encoding = self.clip.visual_transformer(image)
         image_cls = image_encoding[:, 0]
         image_embed = self.clip.to_visual_latent(image_cls)
         return l2norm(image_embed)
 
+    @torch.no_grad()
     def get_text_cond(self, text):
         text_encodings = self.clip.text_transformer(text)
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
@@ -940,11 +948,16 @@ class Unet(nn.Module):
     # if the current settings for the unet are not correct
     # for cascading DDPM, then reinit the unet with the right settings
 
-    def force_lowres_cond(self, lowres_cond):
-        if lowres_cond == self.lowres_cond:
+    def cast_model_parameters(
+        self,
+        *,
+        lowres_cond,
+        channels
+    ):
+        if lowres_cond == self.lowres_cond and channels == self.channels:
             return self
 
-        updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
+        updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
         return self.__class__(**updated_kwargs)
 
     def forward_with_cond_scale(
@@ -1100,6 +1113,7 @@ class Decoder(nn.Module):
         unet,
         *,
         clip,
+        vae = None,
         timesteps = 1000,
         cond_drop_prob = 0.2,
         loss_type = 'l1',
@@ -1120,11 +1134,25 @@ class Decoder(nn.Module):
         # automatically take care of ensuring that first unet is unconditional
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
 
+        unets = cast_tuple(unet)
+        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
+
         self.unets = nn.ModuleList([])
-        for ind, one_unet in enumerate(cast_tuple(unet)):
+        self.vaes = nn.ModuleList([])
+
+        for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
             is_first = ind == 0
-            one_unet = one_unet.force_lowres_cond(not is_first)
+            latent_dim = one_vae.encoded_dim if exists(one_vae) else None
+
+            unet_channels = default(latent_dim, self.channels)
+
+            one_unet = one_unet.cast_model_parameters(
+                lowres_cond = not is_first,
+                channels = unet_channels
+            )
+
             self.unets.append(one_unet)
+            self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
 
         # unet image sizes
 
@@ -1219,10 +1247,12 @@
 
         yield unet.cpu()
 
+    @torch.no_grad()
     def get_text_encodings(self, text):
         text_encodings = self.clip.text_transformer(text)
         return text_encodings[:, 1:]
 
+    @torch.no_grad()
     def get_image_embed(self, image):
         image = resize_image_to(image, self.clip_image_size)
         image_encoding = self.clip.visual_transformer(image)
@@ -1324,25 +1354,43 @@
 
         img = None
 
-        for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
+        for unet, vae, channel, image_size in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes)):
             with self.one_unet_in_gpu(unet = unet):
-                lowres_cond_img = self.to_lowres_cond(
-                    img,
-                    target_image_size = image_size
-                ) if unet.lowres_cond else None
+                lowres_cond_img = None
+                shape = (batch_size, channel, image_size, image_size)
+
+                if unet.lowres_cond:
+                    lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
+
+                if exists(vae):
+                    image_size //= (2 ** vae.layers)
+                    shape = (batch_size, vae.encoded_dim, image_size, image_size)
+
+                    if exists(lowres_cond_img):
+                        lowres_cond_img = vae.encode(lowres_cond_img)
 
                 img = self.p_sample_loop(
                     unet,
-                    (batch_size, channel, image_size, image_size),
+                    shape,
                     image_embed = image_embed,
                     text_encodings = text_encodings,
                     cond_scale = cond_scale,
                     lowres_cond_img = lowres_cond_img
                 )
 
+                if exists(vae):
+                    img = vae.decode(img)
+
         return img
 
-    def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
+    def forward(
+        self,
+        image,
+        text = None,
+        image_embed = None,
+        text_encodings = None,
+        unet_number = None
+    ):
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
         unet_number = default(unet_number, 1)
         unet_index = unet_number - 1
@@ -1350,6 +1398,7 @@
 
         unet = self.get_unet(unet_number)
         target_image_size = self.image_sizes[unet_index]
+        vae = self.vaes[unet_index]
 
         b, c, h, w, device, = *image.shape, image.device
 
@@ -1364,8 +1413,17 @@
         text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
 
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
-        ddpm_image = resize_image_to(image, target_image_size)
-        return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
+        image = resize_image_to(image, target_image_size)
+
+        if exists(vae):
+            vae.eval()
+            with torch.no_grad():
+                image = vae.encode(image)
+
+                if exists(lowres_cond_img):
+                    lowres_cond_img = vae.encode(lowres_cond_img)
+
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
 
 # main class
 
diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index 59a194f..953cba0 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -294,7 +294,7 @@ class VQGanVAE(nn.Module):
         dim,
         image_size,
         channels = 3,
-        num_layers = 4,
+        layers = 4,
         layer_mults = None,
         l2_recon_loss = False,
         use_hinge_loss = True,
@@ -321,35 +321,37 @@
 
         self.image_size = image_size
         self.channels = channels
-        self.num_layers = num_layers
-        self.fmap_size = image_size // (num_layers ** 2)
+        self.layers = layers
+        self.fmap_size = image_size // (layers ** 2)
         self.codebook_size = vq_codebook_size
 
         self.encoders = MList([])
         self.decoders = MList([])
 
-        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
-        assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
+        layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
+        assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
 
         layer_dims = [dim * mult for mult in layer_mults]
         dims = (dim, *layer_dims)
         codebook_dim = layer_dims[-1]
 
+        self.encoded_dim = dims[-1]
+
         dim_pairs = zip(dims[:-1], dims[1:])
 
         append = lambda arr, t: arr.append(t)
         prepend = lambda arr, t: arr.insert(0, t)
 
         if not isinstance(num_resnet_blocks, tuple):
-            num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
+            num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
 
         if not isinstance(use_attn, tuple):
-            use_attn = (*((False,) * (num_layers - 1)), use_attn)
+            use_attn = (*((False,) * (layers - 1)), use_attn)
 
-        assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
-        assert len(use_attn) == num_layers
+        assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
+        assert len(use_attn) == layers
 
-        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
+        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
             append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
             prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
 
@@ -434,12 +436,15 @@ class VQGanVAE(nn.Module):
 
         return fmap
 
-    def decode(self, fmap):
+    def decode(self, fmap, return_indices_and_loss = False):
         fmap, indices, commit_loss = self.vq(fmap)
 
         for dec in self.decoders:
             fmap = dec(fmap)
 
+        if not return_indices_and_loss:
+            return fmap
+
         return fmap, indices, commit_loss
 
     def forward(
@@ -455,7 +460,7 @@ class VQGanVAE(nn.Module):
 
         fmap = self.encode(img)
 
-        fmap, indices, commit_loss = self.decode(fmap)
+        fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
 
         if not return_loss and not return_discr_loss:
             return fmap
diff --git a/setup.py b/setup.py
index 4914bed..c9360b6 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.36',
+  version = '0.0.37',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
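The README section added above notes that latent diffusion can be limited to just the first unet in the cascade. Below is a minimal sketch of that case, not taken from the diff itself and reusing the toy hyperparameters of the README example: a single `VQGanVAE` is cast to a tuple and padded with `None` by the new `pad_tuple_to_length` helper, so only the first unet is rebuilt for the latent space while the remaining unets stay in pixel space.

```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE

# toy CLIP, same hyperparameters as the README example above

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
)

# a single vae - only the first unet will do latent diffusion

vae1 = VQGanVAE(
    dim = 32,
    image_size = 256,
    layers = 3,
    layer_mults = (1, 2, 4)
)

unet1 = Unet(
    dim = 32,
    image_embed_dim = 512,
    cond_dim = 128,
    channels = 3,
    dim_mults = (1, 2, 4, 8)
)

unet2 = Unet(
    dim = 32,
    image_embed_dim = 512,
    channels = 3,
    dim_mults = (1, 2, 4, 8, 16),
    cond_on_image_embeds = True,
    cond_on_text_encodings = False
)

decoder = Decoder(
    clip = clip,
    vae = vae1,               # single vae is cast to a tuple and padded with None -> latent diffusion for unet1 only
    unet = (unet1, unet2),    # unet2 stays in pixel space
    image_sizes = (256, 512),
    timesteps = 100,
    cond_drop_prob = 0.2
)

images = torch.randn(1, 3, 512, 512)

loss = decoder(images, unet_number = 1)  # unet1 is trained on the vae's latent feature map
loss.backward()

loss = decoder(images, unet_number = 2)  # unet2 is trained on pixels, conditioned on the low resolution image
loss.backward()
```

Note that `Decoder` calls `cast_model_parameters` on the first unet with `channels = vae.encoded_dim`, so the `channels = 3` passed to `unet1` above is only a starting value that gets overridden.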
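On the vqgan-vae side, `num_layers` is renamed to `layers`, the encoder's output dimension is exposed as `encoded_dim`, and `decode` now returns only the reconstruction unless `return_indices_and_loss = True` is passed. A small sketch of that surface, assuming the constructor arguments shown in the diff; the shape comment follows from `layers` stride-2 downsamples and `encoded_dim = dim * layer_mults[-1]`.

```python
import torch
from dalle2_pytorch import VQGanVAE

vae = VQGanVAE(
    dim = 32,
    image_size = 256,
    layers = 3,               # renamed from num_layers
    layer_mults = (1, 2, 4)
)

images = torch.randn(1, 3, 256, 256)

fmap = vae.encode(images)     # latent feature map the unets now diffuse over
print(fmap.shape)             # expected (1, 128, 32, 32) - 2 ** layers spatial downsample, encoded_dim channels

recon = vae.decode(fmap)      # new default: reconstruction only, as used by Decoder.sample

recon, indices, commit_loss = vae.decode(fmap, return_indices_and_loss = True)  # previous behavior, still used internally by forward
```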