Compare commits

...

12 Commits

5 changed files with 339 additions and 69 deletions

145
README.md
View File

@@ -10,8 +10,6 @@ The main novelty seems to be an extra layer of indirection with the prior networ
This model is SOTA for text-to-image for now. This model is SOTA for text-to-image for now.
It may also explore an extension of using <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a> in the decoder from Rombach et al.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place. There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
@@ -385,7 +383,127 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training. For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress) ## Experimental
### DALL-E2 with Latent Diffusion
This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP, VQGanVAE
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
)
# 3 unets for the decoder (a la cascading DDPM)
# first two unets are doing latent diffusion
# vqgan-vae must be trained before hand
vae1 = VQGanVAE(
dim = 32,
image_size = 256,
layers = 3,
layer_mults = (1, 2, 4)
)
vae2 = VQGanVAE(
dim = 32,
image_size = 512,
layers = 3,
layer_mults = (1, 2, 4)
)
unet1 = Unet(
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
sparse_attn = True,
sparse_attn_window = 2,
dim_mults = (1, 2, 4, 8)
)
unet2 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False
)
unet3 = Unet(
dim = 32,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
cond_on_image_embeds = True,
cond_on_text_encodings = False,
attend_at_middle = False
)
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
vae = (vae1, vae2), # latent diffusion for unet1 (vae1) and unet2 (vae2), but not for the last unet3
unet = (unet1, unet2, unet3), # insert unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512, 1024), # resolutions, 256 for first unet, 512 for second, 1024 for third
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(1, 3, 1024, 1024).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
with decoder.one_unet_in_gpu(1):
loss = decoder(images, unet_number = 1)
loss.backward()
with decoder.one_unet_in_gpu(2):
loss = decoder(images, unet_number = 2)
loss.backward()
with decoder.one_unet_in_gpu(3):
loss = decoder(images, unet_number = 3)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
## Training wrapper (wip)
Offer training wrappers
## CLI (wip)
```bash ```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog' $ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -393,9 +511,7 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
Once built, images will be saved to the same directory the command is invoked Once built, images will be saved to the same directory the command is invoked
## Training wrapper (wip) <a href="https://github.com/lucidrains/big-sleep">template</a>
Offer training wrappers
## Training CLI (wip) ## Training CLI (wip)
@@ -412,11 +528,15 @@ Offer training wrappers
- [x] add efficient attention in unet - [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning) - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately) - [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional - [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [ ] spend one day cleaning up tech debt in decoder
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
## Citations ## Citations
@@ -454,17 +574,6 @@ Offer training wrappers
} }
``` ```
```bibtex
@misc{zhang2019root,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
year = {2019},
eprint = {1910.07467},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```
```bibtex ```bibtex
@inproceedings{Tu2022MaxViTMV, @inproceedings{Tu2022MaxViTMV,
title = {MaxViT: Multi-Axis Vision Transformer}, title = {MaxViT: Multi-Axis Vision Transformer},

View File

@@ -1,9 +1,51 @@
import click import click
import torch
import torchvision.transforms as T
from pathlib import Path
from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
def safeget(dictionary, keys, default = None):
return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
def simple_slugify(text, max_length = 255):
return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle2_pytorch').version
def main(): def main():
pass pass
@click.command() @click.command()
@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text') @click.argument('text')
def dream(text): def dream(
return 'not ready yet' model,
cond_scale,
text
):
model_path = Path(model)
full_model_path = str(model_path.resolve())
assert model_path.exists(), f'model not found at {full_model_path}'
loaded = torch.load(str(model_path))
version = safeget(loaded, 'version')
print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
prior_init_params = safeget(loaded, 'init_params.prior')
decoder_init_params = safeget(loaded, 'init_params.decoder')
model_params = safeget(loaded, 'model_params')
prior = DiffusionPrior(**prior_init_params)
decoder = Decoder(**decoder_init_params)
dalle2 = DALLE2(prior, decoder)
dalle2.load_state_dict(model_params)
image = dalle2(text, cond_scale = cond_scale)
pil_image = T.ToPILImage()(image)
return pil_image.save(f'./{simple_slugify(text)}.png')

View File

@@ -16,6 +16,7 @@ from einops_exts.torch import EinopsToAndFrom
from kornia.filters import gaussian_blur2d from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer from dalle2_pytorch.tokenizer import tokenizer
from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
# use x-clip # use x-clip
@@ -48,6 +49,12 @@ def is_list_str(x):
return False return False
return all([type(el) == str for el in x]) return all([type(el) == str for el in x])
def pad_tuple_to_length(t, length, fillvalue = None):
remain_length = length - len(t)
if remain_length <= 0:
return t
return (*t, *((fillvalue,) * remain_length))
# for controlling freezing of CLIP # for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad): def set_module_requires_grad_(module, requires_grad):
@@ -540,12 +547,14 @@ class DiffusionPrior(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
@torch.no_grad()
def get_image_embed(self, image): def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image) image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0] image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls) image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed) return l2norm(image_embed)
@torch.no_grad()
def get_text_cond(self, text): def get_text_cond(self, text):
text_encodings = self.clip.text_transformer(text) text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:] text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
@@ -575,14 +584,16 @@ class DiffusionPrior(nn.Module):
return posterior_mean, posterior_variance, posterior_log_variance_clipped return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool): def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
pred = self.net(x, t, **text_cond)
if self.predict_x0: if self.predict_x0:
x_recon = self.net(x, t, **text_cond) x_recon = pred
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else: else:
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond)) x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised: if clip_denoised and not self.predict_x0:
x_recon.clamp_(-1., 1.) x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -940,11 +951,16 @@ class Unet(nn.Module):
# if the current settings for the unet are not correct # if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings # for cascading DDPM, then reinit the unet with the right settings
def force_lowres_cond(self, lowres_cond): def cast_model_parameters(
if lowres_cond == self.lowres_cond: self,
*,
lowres_cond,
channels
):
if lowres_cond == self.lowres_cond and channels == self.channels:
return self return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond} updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond, 'channels': channels}
return self.__class__(**updated_kwargs) return self.__class__(**updated_kwargs)
def forward_with_cond_scale( def forward_with_cond_scale(
@@ -1100,10 +1116,12 @@ class Decoder(nn.Module):
unet, unet,
*, *,
clip, clip,
vae = None,
timesteps = 1000, timesteps = 1000,
cond_drop_prob = 0.2, cond_drop_prob = 0.2,
loss_type = 'l1', loss_type = 'l1',
beta_schedule = 'cosine', beta_schedule = 'cosine',
predict_x0 = False,
image_sizes = None, # for cascading ddpm, image size at each stage image_sizes = None, # for cascading ddpm, image size at each stage
lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
@@ -1120,11 +1138,28 @@ class Decoder(nn.Module):
# automatically take care of ensuring that first unet is unconditional # automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet # while the rest of the unets are conditioned on the low resolution image produced by previous unet
unets = cast_tuple(unet)
vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
self.unets = nn.ModuleList([]) self.unets = nn.ModuleList([])
for ind, one_unet in enumerate(cast_tuple(unet)): self.vaes = nn.ModuleList([])
for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
assert isinstance(one_unet, Unet)
assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
is_first = ind == 0 is_first = ind == 0
one_unet = one_unet.force_lowres_cond(not is_first) latent_dim = one_vae.encoded_dim if exists(one_vae) else None
unet_channels = default(latent_dim, self.channels)
one_unet = one_unet.cast_model_parameters(
lowres_cond = not is_first,
channels = unet_channels
)
self.unets.append(one_unet) self.unets.append(one_unet)
self.vaes.append(one_vae.copy_for_eval())
# unet image sizes # unet image sizes
@@ -1135,6 +1170,10 @@ class Decoder(nn.Module):
self.image_sizes = image_sizes self.image_sizes = image_sizes
self.sample_channels = cast_tuple(self.channels, len(image_sizes)) self.sample_channels = cast_tuple(self.channels, len(image_sizes))
# predict x0 config
self.predict_x0 = cast_tuple(predict_x0, len(unets))
# cascading ddpm related stuff # cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
@@ -1219,10 +1258,12 @@ class Decoder(nn.Module):
yield yield
unet.cpu() unet.cpu()
@torch.no_grad()
def get_text_encodings(self, text): def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text) text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:] return text_encodings[:, 1:]
@torch.no_grad()
def get_image_embed(self, image): def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size) image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image) image_encoding = self.clip.visual_transformer(image)
@@ -1251,34 +1292,47 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.): def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x0 = False, cond_scale = 1.):
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
if clip_denoised: if predict_x0:
x_recon = pred
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
if clip_denoised and not predict_x0:
x_recon.clamp_(-1., 1.) x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad() @torch.no_grad()
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False): def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x0 = False, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised) model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x0 = predict_x0)
noise = noise_like(x.shape, device, repeat_noise) noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0 # no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad() @torch.no_grad()
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1): def p_sample_loop(self, unet, shape, image_embed, predict_x0 = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
device = self.betas.device device = self.betas.device
b = shape[0] b = shape[0]
img = torch.randn(shape, device = device) img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img) img = self.p_sample(
unet,
img,
torch.full((b,), i, device = device, dtype = torch.long),
image_embed = image_embed,
text_encodings = text_encodings,
cond_scale = cond_scale,
lowres_cond_img = lowres_cond_img,
predict_x0 = predict_x0
)
return img return img
@@ -1290,7 +1344,7 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
) )
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None): def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x0 = False, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start)) noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise) x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -1304,12 +1358,14 @@ class Decoder(nn.Module):
cond_drop_prob = self.cond_drop_prob cond_drop_prob = self.cond_drop_prob
) )
target = noise if not predict_x0 else x_start
if self.loss_type == 'l1': if self.loss_type == 'l1':
loss = F.l1_loss(noise, x_recon) loss = F.l1_loss(target, x_recon)
elif self.loss_type == 'l2': elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon) loss = F.mse_loss(target, x_recon)
elif self.loss_type == "huber": elif self.loss_type == "huber":
loss = F.smooth_l1_loss(noise, x_recon) loss = F.smooth_l1_loss(target, x_recon)
else: else:
raise NotImplementedError() raise NotImplementedError()
@@ -1324,25 +1380,42 @@ class Decoder(nn.Module):
img = None img = None
for unet, channel, image_size in tqdm(zip(self.unets, self.sample_channels, self.image_sizes)): for unet, vae, channel, image_size, predict_x0 in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x0)):
with self.one_unet_in_gpu(unet = unet): with self.one_unet_in_gpu(unet = unet):
lowres_cond_img = self.to_lowres_cond( lowres_cond_img = None
img, shape = (batch_size, channel, image_size, image_size)
target_image_size = image_size
) if unet.lowres_cond else None if unet.lowres_cond:
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
image_size = vae.get_encoded_fmap_size(image_size)
shape = (batch_size, vae.encoded_dim, image_size, image_size)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
img = self.p_sample_loop( img = self.p_sample_loop(
unet, unet,
(batch_size, channel, image_size, image_size), shape,
image_embed = image_embed, image_embed = image_embed,
text_encodings = text_encodings, text_encodings = text_encodings,
cond_scale = cond_scale, cond_scale = cond_scale,
predict_x0 = predict_x0,
lowres_cond_img = lowres_cond_img lowres_cond_img = lowres_cond_img
) )
img = vae.decode(img)
return img return img
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None): def forward(
self,
image,
text = None,
image_embed = None,
text_encodings = None,
unet_number = None
):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1) unet_number = default(unet_number, 1)
unet_index = unet_number - 1 unet_index = unet_number - 1
@@ -1350,6 +1423,8 @@ class Decoder(nn.Module):
unet = self.get_unet(unet_number) unet = self.get_unet(unet_number)
target_image_size = self.image_sizes[unet_index] target_image_size = self.image_sizes[unet_index]
vae = self.vaes[unet_index]
predict_x0 = self.predict_x0[unet_index]
b, c, h, w, device, = *image.shape, image.device b, c, h, w, device, = *image.shape, image.device
@@ -1364,8 +1439,16 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
ddpm_image = resize_image_to(image, target_image_size) image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
vae.eval()
with torch.no_grad():
image = vae.encode(image)
if exists(lowres_cond_img):
lowres_cond_img = vae.encode(lowres_cond_img)
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x0 = predict_x0)
# main class # main class
@@ -1392,6 +1475,7 @@ class DALLE2(nn.Module):
cond_scale = 1. cond_scale = 1.
): ):
device = next(self.parameters()).device device = next(self.parameters()).device
one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
if isinstance(text, str) or is_list_str(text): if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text text = [text] if not isinstance(text, (list, tuple)) else text
@@ -1399,4 +1483,8 @@ class DALLE2(nn.Module):
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples) image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed, cond_scale = cond_scale) images = self.decoder.sample(image_embed, cond_scale = cond_scale)
if one_text:
return images[0]
return images return images

View File

@@ -287,6 +287,28 @@ class VQGanAttention(nn.Module):
return out + residual return out + residual
class NullVQGanVAE(nn.Module):
def __init__(
self,
*,
channels
):
super().__init__()
self.encoded_dim = channels
self.layers = 0
def get_encoded_fmap_size(self, size):
return size
def copy_for_eval(self):
return self
def encode(self, x):
return x
def decode(self, x):
return x
class VQGanVAE(nn.Module): class VQGanVAE(nn.Module):
def __init__( def __init__(
self, self,
@@ -294,7 +316,7 @@ class VQGanVAE(nn.Module):
dim, dim,
image_size, image_size,
channels = 3, channels = 3,
num_layers = 4, layers = 4,
layer_mults = None, layer_mults = None,
l2_recon_loss = False, l2_recon_loss = False,
use_hinge_loss = True, use_hinge_loss = True,
@@ -321,35 +343,37 @@ class VQGanVAE(nn.Module):
self.image_size = image_size self.image_size = image_size
self.channels = channels self.channels = channels
self.num_layers = num_layers self.layers = layers
self.fmap_size = image_size // (num_layers ** 2) self.fmap_size = image_size // (layers ** 2)
self.codebook_size = vq_codebook_size self.codebook_size = vq_codebook_size
self.encoders = MList([]) self.encoders = MList([])
self.decoders = MList([]) self.decoders = MList([])
layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(num_layers)))) layer_mults = default(layer_mults, list(map(lambda t: 2 ** t, range(layers))))
assert len(layer_mults) == num_layers, 'layer multipliers must be equal to designated number of layers' assert len(layer_mults) == layers, 'layer multipliers must be equal to designated number of layers'
layer_dims = [dim * mult for mult in layer_mults] layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims) dims = (dim, *layer_dims)
codebook_dim = layer_dims[-1] codebook_dim = layer_dims[-1]
self.encoded_dim = dims[-1]
dim_pairs = zip(dims[:-1], dims[1:]) dim_pairs = zip(dims[:-1], dims[1:])
append = lambda arr, t: arr.append(t) append = lambda arr, t: arr.append(t)
prepend = lambda arr, t: arr.insert(0, t) prepend = lambda arr, t: arr.insert(0, t)
if not isinstance(num_resnet_blocks, tuple): if not isinstance(num_resnet_blocks, tuple):
num_resnet_blocks = (*((0,) * (num_layers - 1)), num_resnet_blocks) num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
if not isinstance(use_attn, tuple): if not isinstance(use_attn, tuple):
use_attn = (*((False,) * (num_layers - 1)), use_attn) use_attn = (*((False,) * (layers - 1)), use_attn)
assert len(num_resnet_blocks) == num_layers, 'number of resnet blocks config must be equal to number of layers' assert len(num_resnet_blocks) == layers, 'number of resnet blocks config must be equal to number of layers'
assert len(use_attn) == num_layers assert len(use_attn) == layers
for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(num_layers), dim_pairs, num_resnet_blocks, use_attn): for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu())) append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu())) prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
@@ -405,6 +429,9 @@ class VQGanVAE(nn.Module):
self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
def get_encoded_fmap_size(self, image_size):
return image_size // (2 ** self.layers)
def copy_for_eval(self): def copy_for_eval(self):
device = next(self.parameters()).device device = next(self.parameters()).device
vae_copy = copy.deepcopy(self.cpu()) vae_copy = copy.deepcopy(self.cpu())
@@ -434,28 +461,32 @@ class VQGanVAE(nn.Module):
return fmap return fmap
def decode(self, fmap): def decode(self, fmap, return_indices_and_loss = False):
fmap = self.vq(fmap) fmap, indices, commit_loss = self.vq(fmap)
for dec in self.decoders: for dec in self.decoders:
fmap = dec(fmap) fmap = dec(fmap)
return fmap if not return_indices_and_loss:
return fmap
return fmap, indices, commit_loss
def forward( def forward(
self, self,
img, img,
return_loss = False, return_loss = False,
return_discr_loss = False, return_discr_loss = False,
return_recons = False return_recons = False,
add_gradient_penalty = True
): ):
batch, channels, height, width, device = *img.shape, img.device batch, channels, height, width, device = *img.shape, img.device
assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}' assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE' assert channels == self.channels, 'number of channels on image or sketch is not equal to the channels set on this VQGanVAE'
fmap, indices, commit_loss = self.encode(img) fmap = self.encode(img)
fmap = self.decode(fmap) fmap, indices, commit_loss = self.decode(fmap, return_indices_and_loss = True)
if not return_loss and not return_discr_loss: if not return_loss and not return_discr_loss:
return fmap return fmap
@@ -472,11 +503,11 @@ class VQGanVAE(nn.Module):
fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img)) fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
gp = gradient_penalty(img, img_discr_logits)
discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits) discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
loss = discr_loss + gp if add_gradient_penalty:
gp = gradient_penalty(img, img_discr_logits)
loss = discr_loss + gp
if return_recons: if return_recons:
return loss, fmap return loss, fmap

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.35', version = '0.0.41',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',