Compare commits


48 Commits
0.0.1 ... 0.0.7

Author SHA1 Message Date
Phil Wang
7e93b9d3c8 make sure classifier free guidance condition scaling is exposed on DALLE2 forward function 2022-04-13 20:14:28 -07:00
Phil Wang
4c827ba94f typo 2022-04-13 19:01:03 -07:00
Phil Wang
cb3923a90f readme tweak 2022-04-13 18:43:34 -07:00
Phil Wang
cc30676a3f lengthen todo 2022-04-13 18:34:09 -07:00
Phil Wang
c7fb327618 link to x-clip 2022-04-13 18:26:30 -07:00
Phil Wang
14ddbc159c cleanup 2022-04-13 18:24:32 -07:00
Phil Wang
0692f1699f favorite quote 2022-04-13 18:17:59 -07:00
Phil Wang
26c4534bc3 readme 2022-04-13 18:11:55 -07:00
Phil Wang
5e06cde4cb always work in the l2normed space for image and text embeddings 2022-04-13 18:08:42 -07:00
Phil Wang
a1a8a78f21 fix everything and make sure it runs end to end, document everything in readme for public 2022-04-13 18:05:25 -07:00
Phil Wang
e5e415297c prepare non-causal attention, for use in the unet in the decoder 2022-04-13 12:04:09 -07:00
Phil Wang
c9377efc93 go for the multi-headed queries, one-headed key/values, proven out in AlphaCode as well as PaLM by now 2022-04-13 12:01:43 -07:00
Phil Wang
2a424b6a28 readme 2022-04-13 10:58:06 -07:00
Phil Wang
d3cded3c6c complete logic in diffusion prior for sampling more than 1 image embeds, taking top similarity 2022-04-13 10:52:31 -07:00
Phil Wang
d573c82f8c add one full attention at the middle of the unet, prepare to do efficient attention employing every trick i know from vision transformer literature 2022-04-13 10:39:06 -07:00
Phil Wang
3aa6f91e7a be transparent 2022-04-13 10:32:11 -07:00
Phil Wang
1bf071af78 allow for predicting image embedding directly during diffusion training. need to fix sampling still 2022-04-13 10:29:29 -07:00
Phil Wang
9f1fe6c7ae update todo 2022-04-13 10:09:08 -07:00
Phil Wang
791d27326a add diffusion code for the image embedding. nearly all the code is there except for the cascading ddpm in the decoder (with upscaling etc) 2022-04-13 10:06:52 -07:00
Phil Wang
6d4e9c97bf todo 2022-04-12 20:50:29 -07:00
Phil Wang
40140b54d6 put on project manager hat 2022-04-12 17:51:23 -07:00
Phil Wang
33d69d3859 take care of DDPM decoder (DDPM for producing image embedding will have a separate objective, predicting directly the embedding rather than the noise [epsilon in paper]) 2022-04-12 17:48:41 -07:00
Phil Wang
862e5ba50e more sketches to base dalle2 class 2022-04-12 17:31:01 -07:00
Phil Wang
25d980ebbf complete naive conditioning of unet with image embedding, with ability to dropout for classifier free guidance 2022-04-12 17:27:39 -07:00
Phil Wang
d546a615c0 complete helper methods for doing condition scaling (classifier free guidance), for decoder unet and prior network 2022-04-12 16:11:16 -07:00
Phil Wang
d4c8373635 complete conditional dropout mask creation for both prior network as well as image decoder unet for classifier free guidance 2022-04-12 14:04:08 -07:00
Phil Wang
c814b2b278 sponsor project button 2022-04-12 13:34:02 -07:00
Phil Wang
74aec9d8ca further prepare attention for classifier free guidance 2022-04-12 13:01:18 -07:00
Phil Wang
7647be2569 prep for classifier free guidance for the image embedding diffusion step, even though not mentioned in paper 2022-04-12 12:57:09 -07:00
Phil Wang
59b8abe09e prepare unet to be conditioned on image embedding, optionally text encodings, and reminder for self to build conditional dropout for classifier free guidance 2022-04-12 12:38:56 -07:00
Phil Wang
46dde54948 for integration of X-CLIP automagically in the gaussian diffusion classes 2022-04-12 12:17:34 -07:00
Phil Wang
40aa304b7e rename to DiffusionPriorNetwork in case ARPriorNetwork is ever built 2022-04-12 11:45:57 -07:00
Phil Wang
fd38eb83c4 complete the main contribution of the paper, the diffusion prior network, minus the diffusion training setup 2022-04-12 11:43:59 -07:00
Phil Wang
83aabd42ca move epsilon inside of square root for further stability in rmsnorm
improvise and use rmsnorm in convnext blocks too
2022-04-12 11:18:36 -07:00
Phil Wang
cf22affcbb bring in modified unet using convnext blocks https://arxiv.org/abs/2201.03545 2022-04-12 10:58:44 -07:00
Phil Wang
522f42f582 start using RMSNorm, used in Gopher and AlphaCode, and as a way to go complete bias-less (purportedly more stable according to PaLM) 2022-04-12 10:45:03 -07:00
Phil Wang
0a60818965 dropouts in transformer, also prep for classifier free guidance in decoder 2022-04-12 10:42:57 -07:00
Phil Wang
604765b563 readme 2022-04-12 10:35:56 -07:00
Phil Wang
7bbc62f3d5 bring in pillow, for image encoding to and from 2022-04-12 10:29:55 -07:00
Phil Wang
771fe0d0d2 also consider accepting tokenizer, so dalle2 forward pass can just be invoked as DALLE2(<prompt string>) 2022-04-12 10:29:29 -07:00
Phil Wang
de75a8af76 link to yannic, since he is the best 2022-04-12 10:27:01 -07:00
Phil Wang
df4dac4f5a bring in attention - it is all we need 2022-04-12 10:23:07 -07:00
Phil Wang
24b428bdfc readme 2022-04-12 10:12:42 -07:00
Phil Wang
2ab042b862 create the eventual dream cli, like bigsleep library 2022-04-12 10:04:17 -07:00
Phil Wang
b93ad8b7a2 add cli file, use click 2022-04-12 09:58:53 -07:00
Phil Wang
f5e0aea140 get ready for CLI tool, just like stylegan2_pytorch 2022-04-12 09:57:54 -07:00
Phil Wang
5e03b7f932 get ready for all the training related classes and functions 2022-04-12 09:54:50 -07:00
Phil Wang
62c0d321a6 sketch 2022-04-12 09:39:42 -07:00
7 changed files with 1170 additions and 30 deletions

1
.github/FUNDING.yml vendored Normal file

@@ -0,0 +1 @@
github: [lucidrains]

335
README.md

@@ -2,7 +2,7 @@
## DALL-E 2 - Pytorch (wip)
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch. <a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a>
The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)
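In code, that extra indirection boils down to two sampling calls, mirroring what `DALLE2.forward` does further down in this diff. A conceptual sketch only; `diffusion_prior` and `decoder` refer to the trained modules from the usage examples below:
```python
# conceptual sketch - names refer to the trained modules in the examples further below
image_embed = diffusion_prior.sample(tokenized_text)   # prior: text -> CLIP image embedding
images = decoder.sample(image_embed, cond_scale = 1.)  # decoder: CLIP image embedding -> pixels
```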
@@ -12,6 +12,318 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
Do let me know if anyone is interested in a Jax version https://github.com/lucidrains/DALLE2-pytorch/discussions/8
For all of you emailing me (there is a lot), the best way to contribute is through pull requests. Everything is open sourced after all. All my thoughts are public. This is your moment to participate.
## Install
```bash
$ pip install dalle2-pytorch
```
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked from
## Training (for deep learning practitioners)
Training DALLE-2 is a 3-step process, with the training of CLIP being the most important
To train CLIP, you can either use the <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already underway.
This repository will demonstrate integration with `x-clip` for starters
```python
import torch
from dalle2_pytorch import CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
use_visual_ssl = True, # whether to do self supervised learning on images
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
text_ssl_loss_weight = 0.05, # weight for text MLM loss
image_ssl_loss_weight = 0.05 # weight for image self-supervised learning loss
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# train
loss = clip(
text,
images,
return_loss = True # needs to be set to True to return contrastive loss
)
loss.backward()
# do the above with as many texts and images as possible in a loop
```
Then, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# unet for the decoder
unet = Unet(
dim = 128,
image_embed_dim = 512,
time_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
# decoder, which contains the unet and clip
decoder = Decoder(
net = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into decoder
loss = decoder(images)
loss.backward()
# do the above for many many many many steps
# then it will learn to generate images based on the CLIP image embeddings
```
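Before moving on, note that the trained decoder can already be sampled from on its own given CLIP image embeddings; a hypothetical sketch using the `Decoder.get_image_embed` and `Decoder.sample` methods that appear further down in this diff:
```python
# hypothetical sketch - reconstruct images from CLIP image embeddings with the trained decoder
image_embed = decoder.get_image_embed(images)                   # (4, 512) CLIP image embeddings
sampled_images = decoder.sample(image_embed, cond_scale = 1.)   # (4, 3, 256, 256)
```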
Finally, the main contribution of the paper: the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
).cuda()
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
num_timesteps = 100,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed text and images into diffusion prior network
loss = diffusion_prior(text, images)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
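Once trained, the prior can also be sampled on its own; a hypothetical sketch using `DiffusionPrior.sample` from further down in this diff (it draws two candidate embeddings per prompt and keeps the one with the highest CLIP similarity):
```python
# hypothetical sketch - sample CLIP image embeddings from tokenized text with the trained prior
image_embed = diffusion_prior.sample(text, num_samples_per_batch = 2)  # (4, 512)
```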
Finally, to generate the DALL-E 2 images from text, pass in the trained `DiffusionPrior` as well as the `Decoder` (which between them contain `CLIP`, a unet, and a causal transformer)
```python
from dalle2_pytorch import DALLE2
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
# send the text as a string if you want to use the simple tokenizer from DALLE v1
# or you can do it as token ids, if you have your own tokenizer
texts = ['glistening morning dew on a flower petal']
images = dalle2(texts) # (1, 3, 256, 256)
```
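The same call also accepts pre-tokenized input; a hypothetical sketch (the vocabulary size and sequence length match the CLIP configured above):
```python
# hypothetical sketch - pass token ids directly if you use your own tokenizer
token_ids = torch.randint(0, 49408, (1, 256)).cuda()
images = dalle2(token_ids)  # (1, 3, 256, 256)
```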
That's it!
Let's see the whole script below
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# train
loss = clip(
text,
images,
return_loss = True
)
loss.backward()
# do above for many steps ...
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
num_timesteps = 100,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = diffusion_prior(text, images)
loss.backward()
# do above for many steps ...
# decoder (with unet)
unet = Unet(
dim = 128,
image_embed_dim = 512,
time_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
decoder = Decoder(
net = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = decoder(images)
loss.backward()
# do above for many steps
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
images = dalle2(
['cute puppy chasing after a squirrel'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image
```
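For reference, `cond_scale` is applied through the `forward_with_cond_scale` helpers shown in the module diff below: the network is run once conditioned and once with the conditioning dropped, and the two predictions are blended. A condensed sketch (the `guided_pred` wrapper itself is hypothetical):
```python
def guided_pred(model, x, *args, cond_scale = 1., **kwargs):
    # classifier free guidance: blend conditional and unconditional predictions
    logits = model(x, *args, **kwargs)
    if cond_scale == 1:
        return logits
    null_logits = model(x, *args, cond_drop_prob = 1., **kwargs)  # conditioning fully dropped
    return null_logits + (logits - null_logits) * cond_scale      # > 1 strengthens the condition
```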
Everything in this readme should run without error
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [ ] augment unet so that it can also be conditioned on text encodings (although in the paper they hinted this didn't make much of a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
## Citations
```bibtex
@@ -39,3 +351,24 @@ Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord
primaryClass = {cs.CV}
}
```
```bibtex
@inproceedings{Liu2022ACF,
title = {A ConvNet for the 2020s},
author = {Zhuang Liu and Hanzi Mao and Chaozheng Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
year = {2022}
}
```
```bibtex
@misc{zhang2019root,
title = {Root Mean Square Layer Normalization},
author = {Biao Zhang and Rico Sennrich},
year = {2019},
eprint = {1910.07467},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

dalle2_pytorch/__init__.py

@@ -1 +1,2 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from x_clip import CLIP

9
dalle2_pytorch/cli.py Normal file

@@ -0,0 +1,9 @@
import click

def main():
    pass

@click.command()
@click.argument('text')
def dream(text):
    # stub for the eventual `dream` command (wip) - image generation is not wired up yet
    return image

dalle2_pytorch/dalle2_pytorch.py

@@ -1,7 +1,16 @@
import math
from tqdm import tqdm
from inspect import isfunction
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
from einops import rearrange, repeat
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from dalle2_pytorch.tokenizer import tokenizer
# use x-clip
@@ -13,7 +22,9 @@ def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
if exists(val):
return val
return d() if isfunction(d) else d
def eval_decorator(fn):
def inner(model, *args, **kwargs):
@@ -24,6 +35,11 @@ def eval_decorator(fn):
return out
return inner
def is_list_str(x):
if not isinstance(x, (list, tuple)):
return False
return all([type(el) == str for el in x])
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -36,44 +52,804 @@ def freeze_all_layers_(module):
def unfreeze_all_layers_(module):
set_module_requires_grad_(module, True)
def freeze_model_and_make_eval_(model):
model.eval()
freeze_all_layers_(model)
# tensor helpers
def l2norm(t):
return F.normalize(t, dim = -1)
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
if prob == 1:
return torch.ones(shape, device = device, dtype = torch.bool)
elif prob == 0:
return torch.zeros(shape, device = device, dtype = torch.bool)
else:
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
# gaussian diffusion helper functions
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, steps, steps)
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
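# alpha-bar follows cos^2(((t / timesteps) + s) / (1 + s) * pi / 2), normalized so alpha-bar(0) = 1;
# betas are recovered from ratios of consecutive alpha-bar values and clipped for numerical stability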
# diffusion prior
class RMSNorm(nn.Module):
def __init__(self, dim, eps = 1e-5):
super().__init__()
self.eps = eps
self.scale = dim ** 0.5
self.gamma = nn.Parameter(torch.ones(dim))
def forward(self, x):
squared_sum = (x ** 2).sum(dim = -1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * self.gamma * self.scale
class ChanRMSNorm(RMSNorm):
def forward(self, x):
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = RMSNorm(dim)
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(mult * dim)
return nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, inner_dim, bias = False),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
class Attention(nn.Module):
def __init__(
self,
dim,
*,
dim_head = 64,
heads = 8,
dropout = 0.,
causal = False
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.causal = causal
self.norm = RMSNorm(dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
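# the learned null key / value guarantees every query has at least one unmasked position to attend to,
# keeping the softmax well defined when all conditioning is dropped for classifier free guidance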
q = q * self.scale
sim = einsum('b h i d, b j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
if self.causal:
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class CausalTransformer(nn.Module):
def __init__(
self,
*,
dim,
depth,
dim_head = 64,
heads = 8,
ff_mult = 4,
norm_out = False,
attn_dropout = 0.,
ff_dropout = 0.
):
super().__init__()
# todo - bring in rotary embeddings or alibi
self.layers = nn.ModuleList([])
for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, causal = True, dim_head = dim_head, heads = heads, dropout = attn_dropout),
FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
]))
self.norm = RMSNorm(dim) if norm_out else nn.Identity() # unclear in paper whether they projected after the classic layer norm for the final denoised image embedding, or just had the transformer output it directly: plan on offering both options
def forward(
self,
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = ff(x) + x
return self.norm(x)
class DiffusionPriorNetwork(nn.Module):
def __init__(
self,
dim,
num_timesteps = 1000,
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
def forward_with_cond_scale(
self,
x,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
self,
image_embed,
diffusion_timesteps,
*,
text_encodings,
text_embed,
mask = None,
cond_drop_prob = 0.2
):
batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
if exists(mask):
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
tokens = torch.cat((
text_encodings,
text_embed,
time_embed,
learned_queries
), dim = -2)
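# token layout per sample: [ text encodings | text embed | time embed | learned query ];
# causal attention lets the final learned-query position gather all of the conditioning before it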
# mask if it doesn't exist
if not exists(mask):
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
# classifier free guidance
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
mask &= rearrange(cond_prob_mask, 'b -> b 1')
# attend
tokens = self.causal_transformer(tokens, mask = mask)
# get learned query, which should predict the image embedding (per DDPM timestep)
pred_image_embed = tokens[..., -1, :]
return pred_image_embed
class DiffusionPrior(nn.Module):
def __init__(
self,
net,
*,
clip
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
predict_x0 = True
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.net = net
self.image_embed_dim = clip.dim_latent
self.channels = clip.image_channels
self.image_size = clip.image_size
self.cond_drop_prob = cond_drop_prob
self.predict_x0 = predict_x0
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
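# when predict_x0 is True, the network output is treated as the denoised image embedding itself
# (see p_mean_variance and the regression target in p_losses); otherwise the usual epsilon objective is used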
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
def get_text_cond(self, text):
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_embed = l2norm(text_embed)
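# the attention mask assumes token id 0 is padding, so padded positions are excluded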
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
if self.predict_x0:
x_recon = self.net(x, t, **text_cond)
# not 100% sure of this above line - for any spectators, let me know in the github issues (or through a pull request) if you know how to correctly do this
# i'll be rereading https://arxiv.org/abs/2111.14822, where i think a similar approach is taken
else:
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, **text_cond))
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, text_cond):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.no_grad()
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, image_embed, t, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
x_recon = self.net(
image_embed_noisy,
t,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)
to_predict = noise if not self.predict_x0 else image_embed
if self.loss_type == 'l1':
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
else:
raise NotImplementedError()
return loss
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
text_cond = self.get_text_cond(text)
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
# decoder
def Upsample(dim):
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
return torch.cat((emb.sin(), emb.cos()), dim = -1)
class ConvNextBlock(nn.Module):
""" https://arxiv.org/abs/2201.03545 """
def __init__(
self,
dim,
dim_out,
*,
cond_dim = None,
mult = 2,
norm = True
):
super().__init__()
need_projection = dim != dim_out
self.mlp = nn.Sequential(
nn.GELU(),
nn.Linear(cond_dim, dim)
) if exists(cond_dim) else None
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
ChanRMSNorm(dim) if norm else nn.Identity(),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
)
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
def forward(self, x, cond = None):
h = self.ds_conv(x)
if exists(self.mlp):
assert exists(cond)
condition = self.mlp(cond)
h = h + rearrange(condition, 'b c -> b c 1 1')
h = self.net(h)
return h + self.res_conv(x)
class Unet(nn.Module):
def __init__(
self,
dim,
*,
image_embed_dim,
time_dim = None,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
):
super().__init__()
self.channels = channels
dims = [channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = default(time_dim, dim)
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
)
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
cond_dim = time_dim + image_embed_dim
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
Upsample(dim_in) if not is_last else nn.Identity()
]))
out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
ConvNextBlock(dim, dim),
nn.Conv2d(dim, out_dim, 1)
)
def forward_with_cond_scale(
self,
x,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
self,
x,
time,
*,
text,
image
image_embed,
text_encodings = None,
cond_drop_prob = 0.
):
return text
batch_size, device = x.shape[0], x.device
t = self.time_mlp(time)
# decoder
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_embed = torch.where(
rearrange(cond_prob_mask, 'b -> b 1'),
image_embed,
rearrange(self.null_image_embed, 'd -> 1 d')
)
t = torch.cat((t, image_embed), dim = -1)
hiddens = []
for convnext, convnext2, downsample in self.downs:
x = convnext(x, t)
x = convnext2(x, t)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, t)
x = convnext2(x, t)
x = upsample(x)
return self.final_conv(x)
class Decoder(nn.Module):
def __init__(
self,
net,
*,
clip,
prior
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1'
):
super().__init__()
assert isinstance(clip, CLIP)
assert isinstance(prior, DiffusionPrior)
freeze_model_and_make_eval_(clip)
self.clip = clip
def forward(
self,
*,
image
):
return image
self.net = net
self.channels = clip.image_channels
self.image_size = clip.image_size
self.cond_drop_prob = cond_drop_prob
betas = cosine_beta_schedule(timesteps)
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
self.register_buffer('betas', betas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
self.register_buffer('posterior_variance', posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
self.register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_image_embed(self, image):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return l2norm(image_embed)
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
return mean, variance, log_variance
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
if clip_denoised:
x_recon.clamp_(-1., 1.)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
return img
@torch.no_grad()
def sample(self, image_embed, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, t, *, image_embed, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
x_recon = self.net(
x_noisy,
t,
image_embed = image_embed,
cond_drop_prob = self.cond_drop_prob
)
if self.loss_type == 'l1':
loss = F.l1_loss(noise, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
else:
raise NotImplementedError()
return loss
def forward(self, image):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
loss = self.p_losses(image, times, image_embed = image_embed)
return loss
# main class
@@ -81,19 +857,29 @@ class DALLE2(nn.Module):
def __init__(
self,
*,
clip,
prior,
decoder
decoder,
prior_num_samples = 2
):
super().__init__()
assert isinstance(clip), CLIP
assert isinstance(prior), DiffusionPrior
assert isinstance(decoder), Decoder
assert isinstance(prior, DiffusionPrior)
assert isinstance(decoder, Decoder)
self.prior = prior.eval()
self.decoder = decoder.eval()
self.prior_num_samples = prior_num_samples
@torch.no_grad()
def forward(
self,
*,
text
text,
cond_scale = 1.
):
return text
device = next(self.parameters()).device
if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
return images

0
dalle2_pytorch/train.py Normal file

setup.py

@@ -4,7 +4,13 @@ setup(
name = 'dalle2-pytorch',
packages = find_packages(exclude=[]),
include_package_data = True,
version = '0.0.1',
entry_points={
'console_scripts': [
'dalle2_pytorch = dalle2_pytorch.cli:main',
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.7',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -16,11 +22,15 @@ setup(
'text to image'
],
install_requires=[
'click',
'einops>=0.4',
'einops-exts',
'torch>=1.6',
'x-clip>=0.4.1',
'yttm'
'einops-exts>=0.0.3',
'pillow',
'torch>=1.10',
'torchvision',
'tqdm',
'x-clip>=0.4.4',
'youtokentome'
],
classifiers=[
'Development Status :: 4 - Beta',