mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
first pass at complete DALL-E2 + Latent Diffusion integration, latent diffusion on any layer(s) of the cascading ddpm in the decoder.
This commit is contained in:
117
README.md
117
README.md
@@ -10,8 +10,6 @@ The main novelty seems to be an extra layer of indirection with the prior networ
|
||||
|
||||
This model is SOTA for text-to-image for now.
|
||||
|
||||
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
@@ -385,6 +383,117 @@ You can also train the decoder on images of greater than the size (say 512x512)
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## Experimental
|
||||
|
||||
### DALL-E2 with Latent Diffusion
|
||||
|
||||
This repository decides to take the next step and offer DALL-E2 combined with latent diffusion, from Rombach et al.
|
||||
|
||||
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
|
||||
|
||||
# trained clip from step 1
|
||||
|
||||
clip = CLIP(
|
||||
dim_text = 512,
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 1,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 1,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
)
|
||||
|
||||
# 2 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
# 1st unet is doing latent diffusion
|
||||
|
||||
vae1 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 256,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
vae2 = VQGanVAE(
|
||||
dim = 32,
|
||||
image_size = 512,
|
||||
layers = 3,
|
||||
layer_mults = (1, 2, 4)
|
||||
)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
sparse_attn = True,
|
||||
sparse_attn_window = 2,
|
||||
dim_mults = (1, 2, 4, 8)
|
||||
)
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False
|
||||
)
|
||||
|
||||
unet3 = Unet(
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
cond_on_image_embeds = True,
|
||||
cond_on_text_encodings = False,
|
||||
attend_at_middle = False
|
||||
)
|
||||
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
|
||||
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 1024, 1024).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
|
||||
with decoder.one_unet_in_gpu(1):
|
||||
loss = decoder(images, unet_number = 1)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(2):
|
||||
loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
|
||||
# chaining the unets from lowest resolution to highest resolution (thus cascading)
|
||||
|
||||
mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
|
||||
```bash
|
||||
@@ -412,11 +521,13 @@ Offer training wrappers
|
||||
- [x] add efficient attention in unet
|
||||
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
|
||||
- [ ] spend one day cleaning up tech debt in decoder
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -48,6 +48,12 @@ def is_list_str(x):
|
||||
return False
|
||||
return all([type(el) == str for el in x])
|
||||
|
||||
def pad_tuple_to_length(t, length):
|
||||
remain_length = length - len(t)
|
||||
if remain_length <= 0:
|
||||
return t
|
||||
return (*t, *((None,) * remain_length))
|
||||
|
||||
# for controlling freezing of CLIP
|
||||
|
||||
def set_module_requires_grad_(module, requires_grad):
|
||||
@@ -540,12 +546,14 @@ class DiffusionPrior(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
image_cls = image_encoding[:, 0]
|
||||
image_embed = self.clip.to_visual_latent(image_cls)
|
||||
return l2norm(image_embed)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_cond(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
|
||||
@@ -940,11 +948,16 @@ class Unet(nn.Module):
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def force_lowres_cond(self, lowres_cond):
|
||||
if lowres_cond == self.lowres_cond:
|
||||
def cast_model_parameters(
|
||||
self,
|
||||
*,
|
||||
lowres_cond,
|
||||
channels
|
||||
):
|
||||
if lowres_cond == self.lowres_cond and channels == self.channels:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
|
||||
return self.__class__(**updated_kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
@@ -1100,6 +1113,7 @@ class Decoder(nn.Module):
|
||||
unet,
|
||||
*,
|
||||
clip,
|
||||
vae = None,
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2,
|
||||
loss_type = 'l1',
|
||||
@@ -1120,11 +1134,25 @@ class Decoder(nn.Module):
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
for ind, one_unet in enumerate(cast_tuple(unet)):
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||
is_first = ind == 0
|
||||
one_unet = one_unet.force_lowres_cond(not is_first)
|
||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||
|
||||
unet_channels = default(latent_dim, self.channels)
|
||||
|
||||
one_unet = one_unet.cast_model_parameters(
|
||||
lowres_cond = not is_first,
|
||||
channels = unet_channels
|
||||
)
|
||||
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
|
||||
|
||||
# unet image sizes
|
||||
|
||||
@@ -1219,10 +1247,12 @@ class Decoder(nn.Module):
|
||||
yield
|
||||
unet.cpu()
|
||||
|
||||
@torch.no_grad()
|
||||
def get_text_encodings(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
return text_encodings[:, 1:]
|
||||
|
||||
@torch.no_grad()
|
||||
def get_image_embed(self, image):
|
||||
image = resize_image_to(image, self.clip_image_size)
|
||||
image_encoding = self.clip.visual_transformer(image)
|
||||
@@ -1324,25 +1354,43 @@ class Decoder(nn.Module):
|
||||
|
||||
img = None
|
||||
|
||||
for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)):
|
||||
for unet, vae, channel, image_size in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes)):
|
||||
with self.one_unet_in_gpu(unet = unet):
|
||||
lowres_cond_img = self.to_lowres_cond(
|
||||
img,
|
||||
target_image_size = image_size
|
||||
) if unet.lowres_cond else None
|
||||
lowres_cond_img = None
|
||||
shape = (batch_size, channel, image_size, image_size)
|
||||
|
||||
if unet.lowres_cond:
|
||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||
|
||||
if exists(vae):
|
||||
image_size //= (2 ** vae.layers)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
(batch_size, channel, image_size, image_size),
|
||||
shape,
|
||||
image_embed = image_embed,
|
||||
text_encodings = text_encodings,
|
||||
cond_scale = cond_scale,
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
if exists(vae):
|
||||
img = vae.decode(img)
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||
def forward(
|
||||
self,
|
||||
image,
|
||||
text = None,
|
||||
image_embed = None,
|
||||
text_encodings = None,
|
||||
unet_number = None
|
||||
):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
unet_index = unet_number - 1
|
||||
@@ -1350,6 +1398,7 @@ class Decoder(nn.Module):
|
||||
unet = self.get_unet(unet_number)
|
||||
|
||||
target_image_size = self.image_sizes[unet_index]
|
||||
vae = self.vaes[unet_index]
|
||||
|
||||
b, c, h, w, device, = *image.shape, image.device
|
||||
|
||||
@@ -1364,8 +1413,17 @@ class Decoder(nn.Module):
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
ddpm_image = resize_image_to(image, target_image_size)
|
||||
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
if exists(vae):
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
# main class
|
||||
|
||||
|
||||
@@ -294,7 +294,7 @@ class VQGanVAE(nn.Module):
|
||||
dim,
|
||||
image_size,
|
||||
channels = 3,
|
||||
num_layers = 4,
|
||||
layers = 4,
|
||||
layer_mults = None,
|
||||
l2_recon_loss = False,
|
||||
use_hinge_loss = True,
|
||||
@@ -321,35 +321,37 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
self.image_size = image_size
|
||||
self.channels = channels
|
||||
self.num_layers = num_layers
|
||||
self.fmap_size = image_size // (num_layers ** 2)
|
||||
self.layers = layers
|
||||
self.fmap_size = image_size // (layers ** 2)
|
||||
self.codebook_size = vq_codebook_size
|
||||
|
||||
self.encoders = MList([])
|
||||
self.decoders = MList([])
|
||||
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers))))
|
||||
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers'
|
||||
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
|
||||
assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
|
||||
|
||||
layer_dims = [dim * mult for mult in layer_mults]
|
||||
dims = (dim, *layer_dims)
|
||||
codebook_dim = layer_dims[-1]
|
||||
|
||||
self.encoded_dim = dims[-1]
|
||||
|
||||
dim_pairs = zip(dims[:-1], dims[1:])
|
||||
|
||||
append = lambda arr, t: arr.append(t)
|
||||
prepend = lambda arr, t: arr.insert(0, t)
|
||||
|
||||
if not isinstance(num_resnet_blocks, tuple):
|
||||
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks)
|
||||
num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
|
||||
|
||||
if not isinstance(use_attn, tuple):
|
||||
use_attn = (*((False,) * (num_layers - 1)), use_attn)
|
||||
use_attn = (*((False,) * (layers - 1)), use_attn)
|
||||
|
||||
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == num_layers
|
||||
assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
|
||||
assert len(use_attn) == layers
|
||||
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
|
||||
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
|
||||
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
|
||||
|
||||
@@ -434,12 +436,15 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
return fmap
|
||||
|
||||
def decode(self, fmap):
|
||||
def decode(self, fmap, return_indices_and_loss = False):
|
||||
fmap, indices, commit_loss = self.vq(fmap)
|
||||
|
||||
for dec in self.decoders:
|
||||
fmap = dec(fmap)
|
||||
|
||||
if not return_indices_and_loss:
|
||||
return fmap
|
||||
|
||||
return fmap, indices, commit_loss
|
||||
|
||||
def forward(
|
||||
@@ -455,7 +460,7 @@ class VQGanVAE(nn.Module):
|
||||
|
||||
fmap = self.encode(img)
|
||||
|
||||
fmap, indices, commit_loss = self.decode(fmap)
|
||||
fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
|
||||
|
||||
if not return_loss and not return_discr_loss:
|
||||
return fmap
|
||||
|
||||
Reference in New Issue
Block a user