diff --git a/README.md b/README.md
index b2fc789..b4344e6 100644
--- a/README.md
+++ b/README.md
@@ -10,8 +10,6 @@ The main novelty seems to be an extra layer of indirection with the prior networ
This model is SOTA for text-to-image for now.
-It may also explore an extension of using latent diffusion in the decoder from Rombach et al.
-
Please join
if you are interested in helping out with the replication
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. Placeholder repository. I will also eventually extend this to text to video, once the repository is in a good place.
@@ -385,6 +383,117 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
+## Experimental
+
+### DALL-E2 with Latent Diffusion
+
+This repository decides to take the next step and offer DALL-E2 combined with latent diffusion, from Rombach et al.
+
+You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
+
+```python
+import torch
+from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
+
+# trained clip from step 1
+
+clip = CLIP(
+ dim_text = 512,
+ dim_image = 512,
+ dim_latent = 512,
+ num_text_tokens = 49408,
+ text_enc_depth = 1,
+ text_seq_len = 256,
+ text_heads = 8,
+ visual_enc_depth = 1,
+ visual_image_size = 256,
+ visual_patch_size = 32,
+ visual_heads = 8
+)
+
+# 2 unets for the decoder (a la cascading DDPM)
+
+# 1st unet is doing latent diffusion
+
+vae1 = VQGanVAE(
+ dim = 32,
+ image_size = 256,
+ layers = 3,
+ layer_mults = (1, 2, 4)
+)
+
+vae2 = VQGanVAE(
+ dim = 32,
+ image_size = 512,
+ layers = 3,
+ layer_mults = (1, 2, 4)
+)
+
+unet1 = Unet(
+ dim = 32,
+ image_embed_dim = 512,
+ cond_dim = 128,
+ channels = 3,
+ sparse_attn = True,
+ sparse_attn_window = 2,
+ dim_mults = (1, 2, 4, 8)
+)
+
+unet2 = Unet(
+ dim = 32,
+ image_embed_dim = 512,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8, 16),
+ cond_on_image_embeds = True,
+ cond_on_text_encodings = False
+)
+
+unet3 = Unet(
+ dim = 32,
+ image_embed_dim = 512,
+ channels = 3,
+ dim_mults = (1, 2, 4, 8, 16),
+ cond_on_image_embeds = True,
+ cond_on_text_encodings = False,
+ attend_at_middle = False
+)
+
+# decoder, which contains the unet(s) and clip
+
+decoder = Decoder(
+ clip = clip,
+ vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
+ unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
+ image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
+ timesteps = 100,
+ cond_drop_prob = 0.2
+).cuda()
+
+# mock images (get a lot of this)
+
+images = torch.randn(1, 3, 1024, 1024).cuda()
+
+# feed images into decoder, specifying which unet you want to train
+# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
+
+with decoder.one_unet_in_gpu(1):
+ loss = decoder(images, unet_number = 1)
+ loss.backward()
+
+with decoder.one_unet_in_gpu(2):
+ loss = decoder(images, unet_number = 2)
+ loss.backward()
+
+# do the above for many steps for both unets
+
+# then it will learn to generate images based on the CLIP image embeddings
+
+# chaining the unets from lowest resolution to highest resolution (thus cascading)
+
+mock_image_embed = torch.randn(1, 512).cuda()
+images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
+```
+
## CLI Usage (work in progress)
```bash
@@ -412,11 +521,13 @@ Offer training wrappers
- [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
-- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
+- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
+- [ ] spend one day cleaning up tech debt in decoder
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
+- [ ] bring in tools to train vqgan-vae
## Citations
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 41cbab4..2b27df3 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -48,6 +48,12 @@ def is_list_str(x):
return False
return all([type(el) == str for el in x])
+def pad_tuple_to_length(t, length):
+ remain_length = length - len(t)
+ if remain_length <= 0:
+ return t
+ return (*t, *((None,) * remain_length))
+
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -540,12 +546,14 @@ class DiffusionPrior(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
+ @torch.no_grad()
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
+ @torch.no_grad()
def get_text_cond(self, text):
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
@@ -940,11 +948,16 @@ class Unet(nn.Module):
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
- def force_lowres_cond(self, lowres_cond):
- if lowres_cond == self.lowres_cond:
+ def cast_model_parameters(
+ self,
+ *,
+ lowres_cond,
+ channels
+ ):
+ if lowres_cond == self.lowres_cond and channels == self.channels:
return self
- updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
+ updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
@@ -1100,6 +1113,7 @@ class Decoder(nn.Module):
unet,
*,
clip,
+ vae = None,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
@@ -1120,11 +1134,25 @@ class Decoder(nn.Module):
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
+ unets = cast_tuple(unet)
+ vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
+
self.unets = nn.ModuleList([])
- for ind, one_unet in enumerate(cast_tuple(unet)):
+ self.vaes = nn.ModuleList([])
+
+ for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
is_first = ind == 0
- one_unet = one_unet.force_lowres_cond(not is_first)
+ latent_dim = one_vae.encoded_dim if exists(one_vae) else None
+
+ unet_channels = default(latent_dim, self.channels)
+
+ one_unet = one_unet.cast_model_parameters(
+ lowres_cond = not is_first,
+ channels = unet_channels
+ )
+
self.unets.append(one_unet)
+ self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
# unet image sizes
@@ -1219,10 +1247,12 @@ class Decoder(nn.Module):
yield
unet.cpu()
+ @torch.no_grad()
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
+ @torch.no_grad()
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
@@ -1324,25 +1354,43 @@ class Decoder(nn.Module):
img = None
- for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
+ for unet, vae, channel, image_size in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes)):
with self.one_unet_in_gpu(unet = unet):
- lowres_cond_img = self.to_lowres_cond(
- img,
- target_image_size = image_size
- ) if unet.lowres_cond else None
+ lowres_cond_img = None
+ shape = (batch_size, channel, image_size, image_size)
+
+ if unet.lowres_cond:
+ lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
+
+ if exists(vae):
+ image_size //= (2 ** vae.layers)
+ shape = (batch_size, vae.encoded_dim, image_size, image_size)
+
+ if exists(lowres_cond_img):
+ lowres_cond_img = vae.encode(lowres_cond_img)
img = self.p_sample_loop(
unet,
- (batch_size, channel, image_size, image_size),
+ shape,
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img
)
+ if exists(vae):
+ img = vae.decode(img)
+
return img
- def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
+ def forward(
+ self,
+ image,
+ text = None,
+ image_embed = None,
+ text_encodings = None,
+ unet_number = None
+ ):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
unet_index = unet_number - 1
@@ -1350,6 +1398,7 @@ class Decoder(nn.Module):
unet = self.get_unet(unet_number)
target_image_size = self.image_sizes[unet_index]
+ vae = self.vaes[unet_index]
b, c, h, w, device, = *image.shape, image.device
@@ -1364,8 +1413,17 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
- ddpm_image = resize_image_to(image, target_image_size)
- return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
+ image = resize_image_to(image, target_image_size)
+
+ if exists(vae):
+ vae.eval()
+ with torch.no_grad():
+ image = vae.encode(image)
+
+ if exists(lowres_cond_img):
+ lowres_cond_img = vae.encode(lowres_cond_img)
+
+ return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
# main class
diff --git a/dalle2_pytorch/vqgan_vae.py b/dalle2_pytorch/vqgan_vae.py
index 59a194f..953cba0 100644
--- a/dalle2_pytorch/vqgan_vae.py
+++ b/dalle2_pytorch/vqgan_vae.py
@@ -294,7 +294,7 @@ class VQGanVAE(nn.Module):
dim,
image_size,
channels = 3,
- num_layers = 4,
+ layers = 4,
layer_mults = None,
l2_recon_loss = False,
use_hinge_loss = True,
@@ -321,35 +321,37 @@ class VQGanVAE(nn.Module):
self.image_size = image_size
self.channels = channels
- self.num_layers = num_layers
- self.fmap_size = image_size // (num_layers ** 2)
+ self.layers = layers
+ self.fmap_size = image_size // (layers ** 2)
self.codebook_size = vq_codebook_size
self.encoders = MList([])
self.decoders = MList([])
- layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
- assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
+ layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
+ assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1]
+ self.encoded_dim = dims[-1]
+
dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple):
- num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
+ num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple):
- use_attn = (*((False,) * (num_layers - 1)), use_attn)
+ use_attn = (*((False,) * (layers - 1)), use_attn)
- assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
- assert len(use_attn) == num_layers
+ assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
+ assert len(use_attn) == layers
- for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
+ for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
@@ -434,12 +436,15 @@ class VQGanVAE(nn.Module):
return fmap
- def decode(self, fmap):
+ def decode(self, fmap, return_indices_and_loss = False):
fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders:
fmap = dec(fmap)
+ if not return_indices_and_loss:
+ return fmap
+
return fmap, indices, commit_loss
def forward(
@@ -455,7 +460,7 @@ class VQGanVAE(nn.Module):
fmap = self.encode(img)
- fmap, indices, commit_loss = self.decode(fmap)
+ fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
if not return_loss and not return_discr_loss:
return fmap
diff --git a/setup.py b/setup.py
index 4914bed..c9360b6 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
- version = '0.0.36',
+ version = '0.0.37',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',