Compare commits


5 Commits

SHA1        Author     Date                        Message
05b74be69a  Phil Wang  2022-04-22 15:23:18 -07:00  use null container pattern to cleanup some conditionals, save more cleanup for next week
a8b5d5d753  Phil Wang  2022-04-22 14:16:43 -07:00  last tweak of readme
976ef7f87c  Phil Wang  2022-04-22 14:15:42 -07:00  project management
fd175bcc0e  Phil Wang  2022-04-22 14:13:33 -07:00  readme
76b32f18b3  Phil Wang  2022-04-22 13:53:13 -07:00  first pass at complete DALL-E2 + Latent Diffusion integration, latent diffusion on any layer(s) of the cascading ddpm in the decoder.
4 changed files with 64 additions and 27 deletions

README.md

@@ -383,9 +383,11 @@ You can also train the decoder on images of greater than the size (say 512x512)
 For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
 
-## Experimental - DALL-E2 with Latent Diffusion
+## Experimental
 
-This repository decides to take the next step and offer DALL-E2 combined with latent diffusion, from Rombach et al.
+### DALL-E2 with Latent Diffusion
+
+This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
 
 You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
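
A note on what "any number you wish" means in practice: per the Decoder changes later in this diff, the vae argument is a tuple that gets padded out to the number of unets, so supplying fewer vaes than unets leaves the remaining unets in pixel space. A runnable toy sketch of that pairing rule (all names here are illustrative stand-ins, not the library's objects):

```python
# toy model of the unet/vae pairing described above; strings stand in
# for real Unet and VQGanVAE instances
from itertools import zip_longest

unets = ('unet1', 'unet2', 'unet3')
vaes  = ('vae1', 'vae2')   # latent diffusion on the first two unets only

pairs = list(zip_longest(unets, vaes, fillvalue = 'null_vae'))
print(pairs)  # [('unet1', 'vae1'), ('unet2', 'vae2'), ('unet3', 'null_vae')]
```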
@@ -409,9 +411,10 @@ clip = CLIP(
     visual_heads = 8
 )
 
-# 2 unets for the decoder (a la cascading DDPM)
-# 1st unet is doing latent diffusion
+# 3 unets for the decoder (a la cascading DDPM)
+# first two unets are doing latent diffusion
+# vqgan-vae must be trained before hand
 
 vae1 = VQGanVAE(
     dim = 32,
@@ -469,7 +472,7 @@ decoder = Decoder(
 # mock images (get a lot of this)
 
-images = torch.randn(1, 3, 512, 512).cuda()
+images = torch.randn(1, 3, 1024, 1024).cuda()
 
 # feed images into decoder, specifying which unet you want to train
 # each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
@@ -482,6 +485,10 @@ with decoder.one_unet_in_gpu(2):
     loss = decoder(images, unet_number = 2)
     loss.backward()
 
+with decoder.one_unet_in_gpu(3):
+    loss = decoder(images, unet_number = 3)
+    loss.backward()
+
 # do the above for many steps for both unets
 # then it will learn to generate images based on the CLIP image embeddings
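
The repeated per-unet step above generalizes to a loop; a sketch only, reusing the decoder and images objects from the readme example:

```python
# sketch: fold the per-unet training step shown above into a loop
# (`decoder` and `images` come from the readme example, not defined here)
for unet_number in (1, 2, 3):
    with decoder.one_unet_in_gpu(unet_number):
        loss = decoder(images, unet_number = unet_number)
        loss.backward()
```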
@@ -492,7 +499,11 @@ mock_image_embed = torch.randn(1, 512).cuda()
 images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
 ```
 
-## CLI Usage (work in progress)
+## Training wrapper (wip)
+
+Offer training wrappers
+
+## CLI (wip)
 
 ```bash
 $ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -500,9 +511,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
 Once built, images will be saved to the same directory the command is invoked
 
-## Training wrapper (wip)
-
-Offer training wrappers
+<a href="https://github.com/lucidrains/big-sleep">template</a>
 
 ## Training CLI (wip)
@@ -525,6 +534,8 @@ Offer training wrappers
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] train on a toy task, offer in colab
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
+- [ ] bring in tools to train vqgan-vae
+- [ ] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
 
 ## Citations
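
Background for the epsilon-vs-x0 todo item above: in standard DDPM, x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps, so predicting the noise and predicting x0 are interchangeable parameterizations of the same objective. A small numeric check (values illustrative):

```python
# under the DDPM forward process, a perfect epsilon prediction recovers x0
import torch

alpha_bar_t = torch.tensor(0.7)   # illustrative cumulative alpha at step t
x0 = torch.randn(1, 3, 8, 8)
noise = torch.randn_like(x0)

x_t = alpha_bar_t.sqrt() * x0 + (1 - alpha_bar_t).sqrt() * noise

# invert: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
x0_from_eps = (x_t - (1 - alpha_bar_t).sqrt() * noise) / alpha_bar_t.sqrt()
assert torch.allclose(x0_from_eps, x0, atol = 1e-6)
```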

dalle2_pytorch/dalle2_pytorch.py

@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
 from kornia.filters import gaussian_blur2d
 
 from dalle2_pytorch.tokenizer import tokenizer
+from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
 
 # use x-clip
@@ -48,11 +49,11 @@ def is_list_str(x):
         return False
     return all([type(el) == str for el in x])
 
-def pad_tuple_to_length(t, length):
+def pad_tuple_to_length(t, length, fillvalue = None):
     remain_length = length - len(t)
     if remain_length <= 0:
         return t
-    return (*t, *((None,) * remain_length))
+    return (*t, *((fillvalue,) * remain_length))
 
 # for controlling freezing of CLIP
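
The fillvalue generalization is easy to sanity-check in isolation; a standalone copy of the helper from the hunk above with a few example calls:

```python
# standalone copy of pad_tuple_to_length from the hunk above, plus examples
def pad_tuple_to_length(t, length, fillvalue = None):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
    return (*t, *((fillvalue,) * remain_length))

print(pad_tuple_to_length((1, 2), 4))                 # (1, 2, None, None)
print(pad_tuple_to_length((1, 2), 4, fillvalue = 0))  # (1, 2, 0, 0)
print(pad_tuple_to_length((1, 2, 3), 2))              # (1, 2, 3), already long enough
```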
@@ -1135,12 +1136,15 @@ class Decoder(nn.Module):
         # while the rest of the unets are conditioned on the low resolution image produced by previous unet
 
         unets = cast_tuple(unet)
-        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
+        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
 
         self.unets = nn.ModuleList([])
         self.vaes = nn.ModuleList([])
 
         for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
+            assert isinstance(one_unet, Unet)
+            assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
+
             is_first = ind == 0
             latent_dim = one_vae.encoded_dim if exists(one_vae) else None
@@ -1152,7 +1156,7 @@ class Decoder(nn.Module):
             )
 
             self.unets.append(one_unet)
-            self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
+            self.vaes.append(one_vae.copy_for_eval())
 
         # unet image sizes
@@ -1362,12 +1366,11 @@ class Decoder(nn.Module):
             if unet.lowres_cond:
                 lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
 
-            if exists(vae):
-                image_size //= (2 ** vae.layers)
-                shape = (batch_size, vae.encoded_dim, image_size, image_size)
+            image_size = vae.get_encoded_fmap_size(image_size)
+            shape = (batch_size, vae.encoded_dim, image_size, image_size)
 
             if exists(lowres_cond_img):
                 lowres_cond_img = vae.encode(lowres_cond_img)
 
             img = self.p_sample_loop(
                 unet,
@@ -1378,8 +1381,7 @@ class Decoder(nn.Module):
                 lowres_cond_img = lowres_cond_img
             )
 
-            if exists(vae):
-                img = vae.decode(img)
+            img = vae.decode(img)
 
         return img
@@ -1415,13 +1417,12 @@ class Decoder(nn.Module):
         lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
 
         image = resize_image_to(image, target_image_size)
 
-        if exists(vae):
-            vae.eval()
-            with torch.no_grad():
-                image = vae.encode(image)
+        vae.eval()
+        with torch.no_grad():
+            image = vae.encode(image)
 
             if exists(lowres_cond_img):
                 lowres_cond_img = vae.encode(lowres_cond_img)
 
         return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)

dalle2_pytorch/vqgan_vae.py

@@ -287,6 +287,28 @@ class VQGanAttention(nn.Module):
         return out + residual
 
+class NullVQGanVAE(nn.Module):
+    def __init__(
+        self,
+        *,
+        channels
+    ):
+        super().__init__()
+        self.encoded_dim = channels
+        self.layers = 0
+
+    def get_encoded_fmap_size(self, size):
+        return size
+
+    def copy_for_eval(self):
+        return self
+
+    def encode(self, x):
+        return x
+
+    def decode(self, x):
+        return x
+
 class VQGanVAE(nn.Module):
     def __init__(
         self,
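
This class is the null container from the commit message: a codec whose every method is an identity, which is what lets the Decoder hunks above call vae.encode / vae.decode unconditionally instead of branching on exists(vae). A standalone demonstration, with the class condensed from the hunk above:

```python
# standalone demo: NullVQGanVAE is an identity codec, so callers can use
# it unconditionally wherever a real VQGanVAE is optional
import torch
from torch import nn

class NullVQGanVAE(nn.Module):
    def __init__(self, *, channels):
        super().__init__()
        self.encoded_dim = channels
        self.layers = 0

    def get_encoded_fmap_size(self, size):
        return size

    def copy_for_eval(self):
        return self

    def encode(self, x):
        return x

    def decode(self, x):
        return x

vae = NullVQGanVAE(channels = 3)
x = torch.randn(1, 3, 64, 64)
assert vae.decode(vae.encode(x)) is x       # encode/decode are no-ops
assert vae.get_encoded_fmap_size(64) == 64  # "latent" size equals pixel size
```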
@@ -407,6 +429,9 @@ class VQGanVAE(nn.Module):
         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
 
+    def get_encoded_fmap_size(self, image_size):
+        return image_size // (2 ** self.layers)
+
     def copy_for_eval(self):
         device = next(self.parameters()).device
         vae_copy = copy.deepcopy(self.cpu())
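
The new method replaces the inlined image_size //= (2 ** vae.layers) arithmetic seen in the Decoder diff; a worked example with an illustrative layer count:

```python
# worked example of get_encoded_fmap_size: each downsampling layer halves
# the feature map (the layer count here is illustrative, not from the diff)
layers = 3
image_size = 1024
print(image_size // (2 ** layers))  # 128
```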

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.37',
+  version = '0.0.38',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',