Compare commits


1 Commit

5 changed files with 75 additions and 286 deletions

View File

@@ -12,7 +12,7 @@ This model is SOTA for text-to-image for now.
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
-There was enough interest for a <a href="https://github.com/lucidrains/dalle2-jax">Jax version</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
+There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Install
@@ -246,6 +246,13 @@ loss = decoder(images, unet_number = 2)
loss.backward()
# do the above for many steps for both unets
+# then it will learn to generate images based on the CLIP image embeddings
+# chaining the unets from lowest resolution to highest resolution (thus cascading)
+mock_image_embed = torch.randn(1, 512).cuda()
+images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
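For orientation, a minimal sketch of that final step, assuming the `DALLE2` wrapper from this package and the `diffusion_prior` and `decoder` trained in the snippets above (exact constructor arguments may differ between the two versions being compared):

```python
from dalle2_pytorch import DALLE2

# hypothetical end-to-end usage, assuming a trained `diffusion_prior` and `decoder`
dalle2 = DALLE2(
    prior = diffusion_prior,   # DiffusionPrior: wraps CLIP and the causal transformer
    decoder = decoder          # Decoder: wraps CLIP and the unet(s)
)

# text -> image embedding (prior) -> image (decoder); returns a tensor of images
images = dalle2('sharing a sunset at the summit of mount everest with my dog', cond_scale = 2.)
```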
@@ -376,11 +383,9 @@ You can also train the decoder on images of greater than the size (say 512x512)
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
-## Experimental
-### DALL-E2 with Latent Diffusion
-This repository decides to take the next step and offer DALL-E2 combined with <a href="https://huggingface.co/spaces/multimodalart/latentdiffusion">latent diffusion</a>, from Rombach et al.
+## Experimental - DALL-E2 with Latent Diffusion
+This repository decides to take the next step and offer DALL-E2 combined with latent diffusion, from Rombach et al.
You can use it as follows. Latent diffusion can be limited to just the first U-Net in the cascade, or to any number you wish.
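A rough sketch of that wiring, using constructor arguments that appear elsewhere in this diff and in the README's other examples (names are indicative, `clip` is the trained CLIP built earlier, and the vqgan-vae must be trained beforehand):

```python
from dalle2_pytorch import Unet, Decoder, VQGanVAE

# hypothetical two-stage decoder where only the first unet diffuses in latent space
vae1 = VQGanVAE(dim = 32, image_size = 256, layers = 3)   # pretrained vqgan-vae for stage 1

unet1 = Unet(dim = 32, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
unet2 = Unet(dim = 32, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))

decoder = Decoder(
    clip = clip,
    vae = (vae1,),          # a vae for the first unet only; later stages stay in pixel space
    unet = (unet1, unet2),  # cascading DDPM, lowest to highest resolution
    image_sizes = (256, 512),
    timesteps = 1000
).cuda()
```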
@@ -404,10 +409,9 @@ clip = CLIP(
    visual_heads = 8
)
-# 3 unets for the decoder (a la cascading DDPM)
-# first two unets are doing latent diffusion
-# vqgan-vae must be trained before hand
+# 2 unets for the decoder (a la cascading DDPM)
+# 1st unet is doing latent diffusion
vae1 = VQGanVAE(
    dim = 32,
@@ -465,7 +469,7 @@ decoder = Decoder(
# mock images (get a lot of this)
-images = torch.randn(1, 3, 1024, 1024).cuda()
+images = torch.randn(1, 3, 512, 512).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
@@ -478,10 +482,6 @@ with decoder.one_unet_in_gpu(2):
    loss = decoder(images, unet_number = 2)
    loss.backward()
-with decoder.one_unet_in_gpu(3):
-    loss = decoder(images, unet_number = 3)
-    loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
@@ -492,11 +492,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
```
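The `one_unet_in_gpu` context manager used above keeps only the unet currently being trained (or sampled from) on the GPU, which is what makes training each resolution separately affordable. A generic sketch of that offloading pattern, not the package's exact implementation:

```python
import torch
from contextlib import contextmanager

@contextmanager
def one_module_on_gpu(module, device = 'cuda'):
    # hypothetical helper: move a single module to the GPU for its step,
    # then send it back to the CPU so idle unets do not hold memory
    module.to(device)
    try:
        yield module
    finally:
        module.cpu()
        torch.cuda.empty_cache()
```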
-## Training wrapper (wip)
-Offer training wrappers
-## CLI (wip)
+## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -504,7 +500,9 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
Once built, images will be saved to the same directory the command is invoked
-<a href="https://github.com/lucidrains/big-sleep">template</a>
+## Training wrapper (wip)
+Offer training wrappers
## Training CLI (wip)
@@ -522,16 +520,11 @@ Once built, images will be saved to the same directory the command is invoked
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [x] build out latent diffusion architecture, with the vq-reg variant (vqgan-vae), make it completely optional and compatible with cascading ddpms
-- [x] for decoder, allow ability to customize objective (predict epsilon vs x0), in case latent diffusion does better with prediction of x0
- [ ] spend one day cleaning up tech debt in decoder
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
-- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] train on a toy task, offer in colab
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
-- [ ] bring in tools to train vqgan-vae
-- [ ] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
-- [ ] experiment with https://arxiv.org/abs/2112.11435 as upsampler, test in https://github.com/lucidrains/lightweight-gan first
## Citations
@@ -563,12 +556,23 @@ Once built, images will be saved to the same directory the command is invoked
```bibtex
@inproceedings{Liu2022ACF,
-    title = {A ConvNet for the 2020https://arxiv.org/abs/2112.11435s},
+    title = {A ConvNet for the 2020s},
    author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
    year = {2022}
}
```
+```bibtex
+@misc{zhang2019root,
+    title = {Root Mean Square Layer Normalization},
+    author = {Biao Zhang and Rico Sennrich},
+    year = {2019},
+    eprint = {1910.07467},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.LG}
+}
+```
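The newly cited RMSNorm (Zhang & Sennrich, 2019) rescales activations by their root mean square, with a learned gain and no mean-centering; a minimal sketch of the layer:

```python
import torch
from torch import nn

class RMSNorm(nn.Module):
    # minimal sketch of root mean square layer normalization
    def __init__(self, dim, eps = 1e-8):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # norm / sqrt(dim) equals the root mean square over the last dimension
        rms = x.norm(dim = -1, keepdim = True) * (x.shape[-1] ** -0.5)
        return x / rms.clamp(min = self.eps) * self.g
```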
```bibtex
@inproceedings{Tu2022MaxViTMV,
    title = {MaxViT: Multi-Axis Vision Transformer},
@@ -577,14 +581,4 @@ Once built, images will be saved to the same directory the command is invoked
}
```
-```bibtex
-@article{Arar2021LearnedQF,
-    title = {Learned Queries for Efficient Local Attention},
-    author = {Moab Arar and Ariel Shamir and Amit H. Bermano},
-    journal = {ArXiv},
-    year = {2021},
-    volume = {abs/2112.11435}
-}
-```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

View File

@@ -1,51 +1,9 @@
import click
-import torch
-import torchvision.transforms as T
-from pathlib import Path
-from dalle2_pytorch import DALLE2, Decoder, DiffusionPrior
-def safeget(dictionary, keys, default = None):
-    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
-def simple_slugify(text, max_length = 255):
-    return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
-def get_pkg_version():
-    from pkg_resources import get_distribution
-    return get_distribution('dalle2_pytorch').version
def main():
    pass
@click.command()
-@click.option('--model', default = './dalle2.pt', help = 'path to trained DALL-E2 model')
-@click.option('--cond_scale', default = 2, help = 'conditioning scale (classifier free guidance) in decoder')
@click.argument('text')
-def dream(
-    model,
-    cond_scale,
-    text
-):
-    model_path = Path(model)
-    full_model_path = str(model_path.resolve())
-    assert model_path.exists(), f'model not found at {full_model_path}'
-    loaded = torch.load(str(model_path))
-    version = safeget(loaded, 'version')
-    print(f'loading DALL-E2 from {full_model_path}, saved at version {version} - current package version is {get_pkg_version()}')
-    prior_init_params = safeget(loaded, 'init_params.prior')
-    decoder_init_params = safeget(loaded, 'init_params.decoder')
-    model_params = safeget(loaded, 'model_params')
-    prior = DiffusionPrior(**prior_init_params)
-    decoder = Decoder(**decoder_init_params)
-    dalle2 = DALLE2(prior, decoder)
-    dalle2.load_state_dict(model_params)
-    image = dalle2(text, cond_scale = cond_scale)
-    pil_image = T.ToPILImage()(image)
-    return pil_image.save(f'./{simple_slugify(text)}.png')
+def dream(text):
+    return 'not ready yet'
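For reference, the helpers removed above are small pure functions; given their definitions, they behave as follows on the README's example prompt (note `safeget` relies on `functools.reduce`, which would also need to be imported):

```python
# simple_slugify turns a prompt into a filesystem-safe filename stem
simple_slugify('sharing a sunset at the summit of mount everest with my dog')
# -> 'sharing_a_sunset_at_the_summit_of_mount_everest_with_my_dog'

# safeget walks nested dicts with a dotted key path, returning a default on any miss
safeget({'init_params': {'prior': {'dim': 512}}}, 'init_params.prior')   # -> {'dim': 512}
safeget({}, 'init_params.decoder')                                       # -> None
```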

View File

@@ -16,7 +16,6 @@ from einops_exts.torch import EinopsToAndFrom
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
-from dalle2_pytorch.vqgan_vae import NullVQGanVAE, VQGanVAE
# use x-clip
@@ -49,11 +48,11 @@ def is_list_str(x):
        return False
    return all([type(el) == str for el in x])
-def pad_tuple_to_length(t, length, fillvalue = None):
+def pad_tuple_to_length(t, length):
    remain_length = length - len(t)
    if remain_length <= 0:
        return t
-    return (*t, *((fillvalue,) * remain_length))
+    return (*t, *((None,) * remain_length))
# for controlling freezing of CLIP
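Both variants pad a tuple up to the requested length; the change only removes the ability to choose the fill value. For example:

```python
pad_tuple_to_length((1, 2), 4)                 # -> (1, 2, None, None)  (both variants)
pad_tuple_to_length((1, 2), 4, fillvalue = 0)  # -> (1, 2, 0, 0)        (only the variant with `fillvalue`)
pad_tuple_to_length((1, 2, 3), 2)              # -> (1, 2, 3), already long enough
```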
@@ -483,7 +482,7 @@ class DiffusionPrior(nn.Module):
        timesteps = 1000,
        cond_drop_prob = 0.2,
        loss_type = "l1",
-        predict_x_start = True,
+        predict_x0 = True,
        beta_schedule = "cosine",
    ):
        super().__init__()
@@ -497,7 +496,7 @@ class DiffusionPrior(nn.Module):
        self.image_size = clip.image_size
        self.cond_drop_prob = cond_drop_prob
-        self.predict_x_start = predict_x_start
+        self.predict_x0 = predict_x0
        # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
        if beta_schedule == "cosine":
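Context for the `predict_x_start` / `predict_x0` rename: the prior's network can be trained to output either the noise or the clean embedding x0 directly, and the two parameterizations are interchangeable through the forward-process identity that `predict_start_from_noise` implements. A sketch of that identity (standard DDPM algebra, not the repository's exact code):

```python
import torch

# forward process: x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise
# so a predicted noise can be converted back into a predicted x_0:
def predict_x0_from_noise(x_t, alpha_cumprod_t, pred_noise):
    return (x_t - (1 - alpha_cumprod_t).sqrt() * pred_noise) / alpha_cumprod_t.sqrt()
```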
@@ -584,16 +583,14 @@ class DiffusionPrior(nn.Module):
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
    def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
-        pred = self.net(x, t, **text_cond)
-        if self.predict_x_start:
-            x_recon = pred
+        if self.predict_x0:
+            x_recon = self.net(x, t, **text_cond)
            # not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
            # i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
        else:
-            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
+            x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
-        if clip_denoised and not self.predict_x_start:
+        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
@@ -639,7 +636,7 @@ class DiffusionPrior(nn.Module):
            **text_cond
        )
-        to_predict = noise if not self.predict_x_start else image_embed
+        to_predict = noise if not self.predict_x0 else image_embed
        if self.loss_type == 'l1':
            loss = F.l1_loss(to_predict, x_recon)
@@ -1121,8 +1118,6 @@ class Decoder(nn.Module):
        cond_drop_prob = 0.2,
        loss_type = 'l1',
        beta_schedule = 'cosine',
-        predict_x_start = False,
-        predict_x_start_for_latent_diffusion = False,
        image_sizes = None, # for cascading ddpm, image size at each stage
        lowres_cond_upsample_mode = 'bilinear', # cascading ddpm - low resolution upsample mode
        lowres_downsample_first = True, # cascading ddpm - resizes to lower resolution, then to next conditional resolution + blur
@@ -1140,15 +1135,12 @@ class Decoder(nn.Module):
        # while the rest of the unets are conditioned on the low resolution image produced by previous unet
        unets = cast_tuple(unet)
-        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets), fillvalue = NullVQGanVAE(channels = self.channels))
+        vaes = pad_tuple_to_length(cast_tuple(vae), len(unets))
        self.unets = nn.ModuleList([])
        self.vaes = nn.ModuleList([])
        for ind, (one_unet, one_vae) in enumerate(zip(unets, vaes)):
-            assert isinstance(one_unet, Unet)
-            assert isinstance(one_vae, (VQGanVAE, NullVQGanVAE))
            is_first = ind == 0
            latent_dim = one_vae.encoded_dim if exists(one_vae) else None
@@ -1160,7 +1152,7 @@ class Decoder(nn.Module):
            )
            self.unets.append(one_unet)
-            self.vaes.append(one_vae.copy_for_eval())
+            self.vaes.append(one_vae.copy_for_eval() if exists(one_vae) else None)
        # unet image sizes
@@ -1171,10 +1163,6 @@ class Decoder(nn.Module):
        self.image_sizes = image_sizes
        self.sample_channels = cast_tuple(self.channels, len(image_sizes))
-        # predict x0 config
-        self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
        # cascading ddpm related stuff
        lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
@@ -1293,47 +1281,34 @@ class Decoder(nn.Module):
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped
-    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
-        pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
-        if predict_x_start:
-            x_recon = pred
-        else:
-            x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
-        if clip_denoised and not predict_x_start:
+    def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
+        pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
+        x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
+        if clip_denoised:
            x_recon.clamp_(-1., 1.)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance
    @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
+    def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
        b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
+        model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    @torch.no_grad()
-    def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
+    def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
        device = self.betas.device
        b = shape[0]
        img = torch.randn(shape, device = device)
        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
-            img = self.p_sample(
-                unet,
-                img,
-                torch.full((b,), i, device = device, dtype = torch.long),
-                image_embed = image_embed,
-                text_encodings = text_encodings,
-                cond_scale = cond_scale,
-                lowres_cond_img = lowres_cond_img,
-                predict_x_start = predict_x_start
-            )
+            img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
        return img
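`forward_with_cond_scale` above implements classifier-free guidance: the unet is evaluated with and without conditioning and the two predictions are extrapolated by `cond_scale` (a scale of 1 is just the conditional prediction). A sketch of the scaling rule, with hypothetical argument names:

```python
def guided_prediction(unet, x, t, image_embed, cond_scale = 1.):
    # classifier-free guidance sketch: extrapolate away from the unconditional output
    cond = unet(x, t, image_embed = image_embed, cond_drop_prob = 0.)
    if cond_scale == 1:
        return cond
    null = unet(x, t, image_embed = image_embed, cond_drop_prob = 1.)   # conditioning dropped
    return null + (cond - null) * cond_scale
```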
@@ -1345,7 +1320,7 @@ class Decoder(nn.Module):
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )
-    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
+    def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
        noise = default(noise, lambda: torch.randn_like(x_start))
        x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -1359,14 +1334,12 @@ class Decoder(nn.Module):
            cond_drop_prob = self.cond_drop_prob
        )
-        target = noise if not predict_x_start else x_start
        if self.loss_type == 'l1':
-            loss = F.l1_loss(target, x_recon)
+            loss = F.l1_loss(noise, x_recon)
        elif self.loss_type == 'l2':
-            loss = F.mse_loss(target, x_recon)
+            loss = F.mse_loss(noise, x_recon)
        elif self.loss_type == "huber":
-            loss = F.smooth_l1_loss(target, x_recon)
+            loss = F.smooth_l1_loss(noise, x_recon)
        else:
            raise NotImplementedError()
@@ -1381,7 +1354,7 @@ class Decoder(nn.Module):
        img = None
-        for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
+        for unet, vae, channel, image_size in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes)):
            with self.one_unet_in_gpu(unet = unet):
                lowres_cond_img = None
                shape = (batch_size, channel, image_size, image_size)
@@ -1389,11 +1362,12 @@ class Decoder(nn.Module):
                if unet.lowres_cond:
                    lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
-                image_size = vae.get_encoded_fmap_size(image_size)
-                shape = (batch_size, vae.encoded_dim, image_size, image_size)
+                if exists(vae):
+                    image_size //= (2 ** vae.layers)
+                    shape = (batch_size, vae.encoded_dim, image_size, image_size)
                if exists(lowres_cond_img):
                    lowres_cond_img = vae.encode(lowres_cond_img)
                img = self.p_sample_loop(
                    unet,
@@ -1401,11 +1375,11 @@ class Decoder(nn.Module):
                    image_embed = image_embed,
                    text_encodings = text_encodings,
                    cond_scale = cond_scale,
-                    predict_x_start = predict_x_start,
                    lowres_cond_img = lowres_cond_img
                )
-                img = vae.decode(img)
+                if exists(vae):
+                    img = vae.decode(img)
        return img
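Both sides compute the same latent resolution, since each vqgan-vae layer halves the spatial size: the removed `get_encoded_fmap_size` and the inline `image_size // (2 ** vae.layers)` agree. For example:

```python
# a vae with 3 downsampling layers maps 256x256 images to 32x32 latent feature maps
image_size, layers = 256, 3
encoded_size = image_size // (2 ** layers)   # 32
# the unet for that stage then denoises tensors of shape (batch, vae.encoded_dim, 32, 32)
```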
@@ -1425,7 +1399,6 @@ class Decoder(nn.Module):
        target_image_size = self.image_sizes[unet_index]
        vae = self.vaes[unet_index]
-        predict_x_start = self.predict_x_start[unet_index]
        b, c, h, w, device, = *image.shape, image.device
@@ -1442,14 +1415,15 @@ class Decoder(nn.Module):
        lowres_cond_img = self.to_lowres_cond(image, target_image_size = target_image_size, downsample_image_size = self.image_sizes[unet_index - 1]) if unet_number > 1 else None
        image = resize_image_to(image, target_image_size)
-        vae.eval()
-        with torch.no_grad():
-            image = vae.encode(image)
+        if exists(vae):
+            vae.eval()
+            with torch.no_grad():
+                image = vae.encode(image)
            if exists(lowres_cond_img):
                lowres_cond_img = vae.encode(lowres_cond_img)
-        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
+        return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
# main class
@@ -1476,7 +1450,6 @@ class DALLE2(nn.Module):
        cond_scale = 1.
    ):
        device = next(self.parameters()).device
-        one_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)
        if isinstance(text, str) or is_list_str(text):
            text = [text] if not isinstance(text, (list, tuple)) else text
@@ -1484,8 +1457,4 @@ class DALLE2(nn.Module):
        image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
        images = self.decoder.sample(image_embed, cond_scale = cond_scale)
-        if one_text:
-            return images[0]
        return images

View File

@@ -243,112 +243,6 @@ class ResBlock(nn.Module):
    def forward(self, x):
        return self.net(x) + x
-# attention-based upsampling
-# from https://arxiv.org/abs/2112.11435
-class QueryAndAttend(nn.Module):
-    def __init__(
-        self,
-        *,
-        dim,
-        num_queries = 1,
-        dim_head = 32,
-        heads = 8,
-        window_size = 3
-    ):
-        super().__init__()
-        self.scale = dim_head ** -0.5
-        inner_dim = dim_head * heads
-        self.heads = heads
-        self.dim_head = dim_head
-        self.window_size = window_size
-        self.num_queries = num_queries
-        self.rel_pos_bias = nn.Parameter(torch.randn(heads, num_queries, window_size * window_size, 1, 1))
-        self.queries = nn.Parameter(torch.randn(heads, num_queries, dim_head))
-        self.to_kv = nn.Conv2d(dim, dim_head * 2, 1, bias = False)
-        self.to_out = nn.Conv2d(inner_dim, dim, 1, bias = False)
-    def forward(self, x):
-        """
-        einstein notation
-        b - batch
-        h - heads
-        l - num queries
-        d - head dimension
-        x - height
-        y - width
-        j - source sequence for attending to (kernel size squared in this case)
-        """
-        wsz, heads, dim_head, num_queries = self.window_size, self.heads, self.dim_head, self.num_queries
-        batch, _, height, width = x.shape
-        is_one_query = self.num_queries == 1
-        # queries, keys, values
-        q = self.queries * self.scale
-        k, v = self.to_kv(x).chunk(2, dim = 1)
-        # similarities
-        sim = einsum('h l d, b d x y -> b h l x y', q, k)
-        sim = rearrange(sim, 'b ... x y -> b (...) x y')
-        # unfold the similarity scores, with float(-inf) as padding value
-        mask_value = -torch.finfo(sim.dtype).max
-        sim = F.pad(sim, ((wsz // 2,) * 4), value = mask_value)
-        sim = F.unfold(sim, kernel_size = wsz)
-        sim = rearrange(sim, 'b (h l j) (x y) -> b h l j x y', h = heads, l = num_queries, x = height, y = width)
-        # rel pos bias
-        sim = sim + self.rel_pos_bias
-        # numerically stable attention
-        sim = sim - sim.amax(dim = -3, keepdim = True).detach()
-        attn = sim.softmax(dim = -3)
-        # unfold values
-        v = F.pad(v, ((wsz // 2,) * 4), value = 0.)
-        v = F.unfold(v, kernel_size = wsz)
-        v = rearrange(v, 'b (d j) (x y) -> b d j x y', d = dim_head, x = height, y = width)
-        # aggregate values
-        out = einsum('b h l j x y, b d j x y -> b l h d x y', attn, v)
-        # combine heads
-        out = rearrange(out, 'b l h d x y -> (b l) (h d) x y')
-        out = self.to_out(out)
-        out = rearrange(out, '(b l) d x y -> b l d x y', b = batch)
-        # return original input if one query
-        if is_one_query:
-            out = rearrange(out, 'b 1 ... -> b ...')
-        return out
-class QueryAttnUpsample(nn.Module):
-    def __init__(self, dim, **kwargs):
-        super().__init__()
-        self.norm = LayerNormChan(dim)
-        self.qna = QueryAndAttend(dim = dim, num_queries = 4, **kwargs)
-    def forward(self, x):
-        x = self.norm(x)
-        out = self.qna(x)
-        out = rearrange(out, 'b (w1 w2) c h w -> b c (h w1) (w w2)', w1 = 2, w2 = 2)
-        return out
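The removed `QueryAttnUpsample` asks `QueryAndAttend` for 4 query outputs per position and rearranges them into a 2x2 block, i.e. a learned 2x upsampler. A shape-only usage sketch, assuming the definitions above (and the `LayerNormChan` defined elsewhere in this file):

```python
import torch

x = torch.randn(1, 32, 16, 16)            # (batch, dim, height, width)
upsample = QueryAttnUpsample(dim = 32)    # wraps QueryAndAttend with num_queries = 4
out = upsample(x)
print(out.shape)                          # torch.Size([1, 32, 32, 32]), spatial dims doubled
```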
-# vqgan attention layer
class VQGanAttention(nn.Module):
    def __init__(
        self,
@@ -393,28 +287,6 @@ class VQGanAttention(nn.Module):
        return out + residual
-class NullVQGanVAE(nn.Module):
-    def __init__(
-        self,
-        *,
-        channels
-    ):
-        super().__init__()
-        self.encoded_dim = channels
-        self.layers = 0
-    def get_encoded_fmap_size(self, size):
-        return size
-    def copy_for_eval(self):
-        return self
-    def encode(self, x):
-        return x
-    def decode(self, x):
-        return x
class VQGanVAE(nn.Module):
    def __init__(
        self,
@@ -481,7 +353,7 @@ class VQGanVAE(nn.Module):
        for layer_index, (dim_in, dim_out), layer_num_resnet_blocks, layer_use_attn in zip(range(layers), dim_pairs, num_resnet_blocks, use_attn):
            append(self.encoders, nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride = 2, padding = 1), leaky_relu()))
-            prepend(self.decoders, nn.Sequential(QueryAttnUpsample(dim_out), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
+            prepend(self.decoders, nn.Sequential(nn.Upsample(scale_factor = 2, mode = 'bilinear', align_corners = False), nn.Conv2d(dim_out, dim_in, 3, padding = 1), leaky_relu()))
            if layer_use_attn:
                prepend(self.decoders, VQGanAttention(dim = dim_out, heads = attn_heads, dim_head = attn_dim_head, dropout = attn_dropout))
@@ -535,9 +407,6 @@ class VQGanVAE(nn.Module):
        self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
        self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
-    def get_encoded_fmap_size(self, image_size):
-        return image_size // (2 ** self.layers)
    def copy_for_eval(self):
        device = next(self.parameters()).device
        vae_copy = copy.deepcopy(self.cpu())
@@ -583,8 +452,7 @@ class VQGanVAE(nn.Module):
        img,
        return_loss = False,
        return_discr_loss = False,
-        return_recons = False,
-        add_gradient_penalty = True
+        return_recons = False
    ):
        batch, channels, height, width, device = *img.shape, img.device
        assert height == self.image_size and width == self.image_size, 'height and width of input image must be equal to {self.image_size}'
@@ -609,11 +477,11 @@ class VQGanVAE(nn.Module):
        fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
+        gp = gradient_penalty(img, img_discr_logits)
        discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
-        if add_gradient_penalty:
-            gp = gradient_penalty(img, img_discr_logits)
-            loss = discr_loss + gp
+        loss = discr_loss + gp
        if return_recons:
            return loss, fmap
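One side gates the penalty behind `add_gradient_penalty`, the other always applies it; in both, `gradient_penalty` is the usual gradient-norm regularizer on the discriminator's real inputs. A sketch of such a function, not necessarily the package's exact implementation:

```python
import torch
from torch.autograd import grad

def gradient_penalty(images, logits, weight = 10):
    # images must have requires_grad set before the discriminator pass
    gradients = grad(
        outputs = logits, inputs = images,
        grad_outputs = torch.ones_like(logits),
        create_graph = True, retain_graph = True, only_inputs = True
    )[0]
    gradients = gradients.reshape(images.shape[0], -1)
    # penalize deviation of the per-sample gradient norm from 1
    return weight * ((gradients.norm(2, dim = 1) - 1) ** 2).mean()
```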

View File

@@ -10,7 +10,7 @@ setup(
            'dream = dalle2_pytorch.cli:dream'
        ],
    },
-    version = '0.0.44',
+    version = '0.0.37',
    license='MIT',
    description = 'DALL-E 2',
    author = 'Phil Wang',