mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05b74be69a | ||
|
|
a8b5d5d753 | ||
|
|
976ef7f87c | ||
|
|
fd175bcc0e | ||
|
|
76b32f18b3 |
29
README.md
29
README.md
@@ -383,9 +383,11 @@ You can also train the decoder on images of greater than the size (say 512x512)
|
||||
|
||||
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
|
||||
|
||||
## Experimental - DALL-E2 with Latent Diffusion
|
||||
## Experimental
|
||||
|
||||
This repository decides to take the next step and offer DALL-E2 combined with latent diffusion, from Rombach et al.
|
||||
### DALL-E2 with Latent Diffusion
|
||||
|
||||
This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
|
||||
|
||||
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
|
||||
|
||||
@@ -409,9 +411,10 @@ clip = CLIP(
|
||||
visual_heads = 8
|
||||
)
|
||||
|
||||
# 2 unets for the decoder (a la cascading DDPM)
|
||||
# 3 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
# 1st unet is doing latent diffusion
|
||||
# first two unets are doing latent diffusion
|
||||
# vqgan-vae must be trained before hand
|
||||
|
||||
vae1 = VQGanVAE(
|
||||
dim = 32,
|
||||
@@ -469,7 +472,7 @@ decoder = Decoder(
|
||||
|
||||
# mock images (get a lot of this)
|
||||
|
||||
images = torch.randn(1, 3, 512, 512).cuda()
|
||||
images = torch.randn(1, 3, 1024, 1024).cuda()
|
||||
|
||||
# feed images into decoder, specifying which unet you want to train
|
||||
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
|
||||
@@ -482,6 +485,10 @@ with decoder.one_unet_in_gpu(2):
|
||||
loss = decoder(images, unet_number = 2)
|
||||
loss.backward()
|
||||
|
||||
with decoder.one_unet_in_gpu(3):
|
||||
loss = decoder(images, unet_number = 3)
|
||||
loss.backward()
|
||||
|
||||
# do the above for many steps for both unets
|
||||
|
||||
# then it will learn to generate images based on the CLIP image embeddings
|
||||
@@ -492,7 +499,11 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
||||
```
|
||||
|
||||
## CLI Usage (work in progress)
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
|
||||
## CLI (wip)
|
||||
|
||||
```bash
|
||||
$ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
@@ -500,9 +511,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
|
||||
|
||||
Once built, images will be saved to the same directory the command is invoked
|
||||
|
||||
## Training wrapper (wip)
|
||||
|
||||
Offer training wrappers
|
||||
<a href="https://github.com/lucidrains/big-sleep">template</a>
|
||||
|
||||
## Training CLI (wip)
|
||||
|
||||
@@ -525,6 +534,8 @@ Offer training wrappers
|
||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
|
||||
- [ ] bring in tools to train vqgan-vae
|
||||
- [ ] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
|
||||
|
||||
## Citations
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
|
||||
from kornia.filters import gaussian_blur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
|
||||
|
||||
# use x-clip
|
||||
|
||||
@@ -48,11 +49,11 @@ def is_list_str(x):
|
||||
return False
|
||||
return all([type(el) == str for el in x])
|
||||
|
||||
def pad_tuple_to_length(t, length):
|
||||
def pad_tuple_to_length(t, length, fillvalue = None):
|
||||
remain_length = length - len(t)
|
||||
if remain_length <= 0:
|
||||
return t
|
||||
return (*t, *((None,) * remain_length))
|
||||
return (*t, *((fillvalue,) * remain_length))
|
||||
|
||||
# for controlling freezing of CLIP
|
||||
|
||||
@@ -1135,12 +1136,15 @@ class Decoder(nn.Module):
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
unets = cast_tuple(unet)
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
|
||||
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
self.vaes = nn.ModuleList([])
|
||||
|
||||
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
|
||||
assert isinstance(one_unet, Unet)
|
||||
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
|
||||
|
||||
is_first = ind == 0
|
||||
latent_dim = one_vae.encoded_dim if exists(one_vae) else None
|
||||
|
||||
@@ -1152,7 +1156,7 @@ class Decoder(nn.Module):
|
||||
)
|
||||
|
||||
self.unets.append(one_unet)
|
||||
self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
|
||||
self.vaes.append(one_vae.copy_for_eval())
|
||||
|
||||
# unet image sizes
|
||||
|
||||
@@ -1362,12 +1366,11 @@ class Decoder(nn.Module):
|
||||
if unet.lowres_cond:
|
||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||
|
||||
if exists(vae):
|
||||
image_size //= (2 ** vae.layers)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
image_size = vae.get_encoded_fmap_size(image_size)
|
||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
img = self.p_sample_loop(
|
||||
unet,
|
||||
@@ -1378,8 +1381,7 @@ class Decoder(nn.Module):
|
||||
lowres_cond_img = lowres_cond_img
|
||||
)
|
||||
|
||||
if exists(vae):
|
||||
img = vae.decode(img)
|
||||
img = vae.decode(img)
|
||||
|
||||
return img
|
||||
|
||||
@@ -1415,13 +1417,12 @@ class Decoder(nn.Module):
|
||||
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
|
||||
image = resize_image_to(image, target_image_size)
|
||||
|
||||
if exists(vae):
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
vae.eval()
|
||||
with torch.no_grad():
|
||||
image = vae.encode(image)
|
||||
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
if exists(lowres_cond_img):
|
||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||
|
||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
|
||||
|
||||
|
||||
@@ -287,6 +287,28 @@ class VQGanAttention(nn.Module):
|
||||
|
||||
return out + residual
|
||||
|
||||
class NullVQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
channels
|
||||
):
|
||||
super().__init__()
|
||||
self.encoded_dim = channels
|
||||
self.layers = 0
|
||||
|
||||
def get_encoded_fmap_size(self, size):
|
||||
return size
|
||||
|
||||
def copy_for_eval(self):
|
||||
return self
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
class VQGanVAE(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -407,6 +429,9 @@ class VQGanVAE(nn.Module):
|
||||
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
|
||||
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
|
||||
|
||||
def get_encoded_fmap_size(self, image_size):
|
||||
return image_size // (2 ** self.layers)
|
||||
|
||||
def copy_for_eval(self):
|
||||
device = next(self.parameters()).device
|
||||
vae_copy = copy.deepcopy(self.cpu())
|
||||
|
||||
Reference in New Issue
Block a user