Compare commits

...

20 Commits

Author SHA1 Message Date
Phil Wang
23c401a5d5 use the eval decorator 2022-04-14 10:13:43 -07:00
Phil Wang
68e9883f59 use cross attention for conditioning unet based on image embedding tokens (which opens up the door on conditioning on text encodings as well 2022-04-14 10:10:04 -07:00
Phil Wang
95b018374a start using swish glu everywhere, given success of PaLM 2022-04-14 09:34:32 -07:00
Phil Wang
8b5c2385b0 better naming 2022-04-14 09:24:31 -07:00
Phil Wang
f2c52d8239 fix bug with classifier free guidance for prior network, even though it seems it may not be used 2022-04-14 09:21:51 -07:00
Phil Wang
97e951221b bring in blur, as it will be used somewhere in the cascading DDPM in the decoder eventually, once i figure it out 2022-04-14 09:16:09 -07:00
Phil Wang
e1b0c140f1 cleanup readme 2022-04-14 08:51:22 -07:00
Phil Wang
5989569a44 link to OpenCLIP effort 2022-04-14 08:31:15 -07:00
Phil Wang
82464d7bd3 per-fect 2022-04-14 08:30:07 -07:00
Phil Wang
7fb3f695d5 offer continuously parameterized time embedding for diffusion prior network, remove a hyperparameter that may trip up people, if not set correctly 2022-04-14 08:28:11 -07:00
Phil Wang
7e93b9d3c8 make sure classifier free guidance condition scaling is exposed on DALLE2 forward function 2022-04-13 20:14:28 -07:00
Phil Wang
4c827ba94f typo 2022-04-13 19:01:03 -07:00
Phil Wang
cb3923a90f readme tweak 2022-04-13 18:43:34 -07:00
Phil Wang
cc30676a3f lengthen todo 2022-04-13 18:34:09 -07:00
Phil Wang
c7fb327618 link to x-clip 2022-04-13 18:26:30 -07:00
Phil Wang
14ddbc159c cleanup 2022-04-13 18:24:32 -07:00
Phil Wang
0692f1699f favorite quote 2022-04-13 18:17:59 -07:00
Phil Wang
26c4534bc3 readme 2022-04-13 18:11:55 -07:00
Phil Wang
5e06cde4cb always work in the l2normed space for image and text embeddings 2022-04-13 18:08:42 -07:00
Phil Wang
a1a8a78f21 fix everything and make sure it runs end to end, document everything in readme for public 2022-04-13 18:05:25 -07:00
4 changed files with 576 additions and 123 deletions

291
README.md
View File

@@ -22,9 +22,283 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
$ pip install dalle2-pytorch
```
## Usage (work in progress)
## Usage
<a href="https://github.com/lucidrains/big-sleep">template</a>
To train DALLE-2 is a 3 step process, with the training of CLIP being the most important
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.
This repository will demonstrate integration with `x-clip` for starters
```python
import torch
from dalle2_pytorch import CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
use_all_token_embeds = True, # whether to use fine-grained contrastive learning (FILIP)
decoupled_contrastive_learning = True, # use decoupled contrastive learning (DCL) objective function, removing positive pairs from the denominator of the InfoNCE loss (CLOOB + DCL)
extra_latent_projection = True, # whether to use separate projections for text-to-image vs image-to-text comparisons (CLOOB)
use_visual_ssl = True, # whether to do self supervised learning on iages
visual_ssl_type = 'simclr', # can be either 'simclr' or 'simsiam', depending on using DeCLIP or SLIP
use_mlm = False, # use masked language learning (MLM) on text (DeCLIP)
text_ssl_loss_weight = 0.05, # weight for text MLM loss
image_ssl_loss_weight = 0.05 # weight for image self-supervised learning loss
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# train
loss = clip(
text,
images,
return_loss = True # needs to be set to True to return contrastive loss
)
loss.backward()
# do the above with as many texts and images as possible in a loop
```
Then, you will need to train the decoder, which learns to generate images based on the image embedding coming from the trained CLIP above
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# unet for the decoder
unet = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
# decoder, which contains the unet and clip
decoder = Decoder(
net = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into decoder
loss = decoder(images)
loss.backward()
# do the above for many many many many steps
# then it will learn to generate images based on the CLIP image embeddings
```
Finally, the main contribution of the paper. The repository offers the diffusion prior network. It takes the CLIP text embeddings and tries to generate the CLIP image embeddings. Again, you will need the trained CLIP from the first step
```python
import torch
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, CLIP
# get trained CLIP from step one
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8,
).cuda()
# setup prior network, which contains an autoregressive transformer
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
# diffusion prior network, which contains the CLIP and network (with transformer) above
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed text and images into diffusion prior network
loss = diffusion_prior(text, images)
loss.backward()
# do the above for many many many steps
# now the diffusion prior can generate image embeddings from the text embeddings
```
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
```python
from dalle2_pytorch import DALLE2
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
# send the text as a string if you want to use the simple tokenizer from DALLE v1
# or you can do it as token ids, if you have your own tokenizer
texts = ['glistening morning dew on a flower petal']
images = dalle2(texts) # (1, 3, 256, 256)
```
That's it!
Let's see the whole script below
```python
import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# train
loss = clip(
text,
images,
return_loss = True
)
loss.backward()
# do above for many steps ...
# prior networks (with transformer)
prior_network = DiffusionPriorNetwork(
dim = 512,
depth = 6,
dim_head = 64,
heads = 8
).cuda()
diffusion_prior = DiffusionPrior(
net = prior_network,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = diffusion_prior(text, images)
loss.backward()
# do above for many steps ...
# decoder (with unet)
unet = Unet(
dim = 128,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
decoder = Decoder(
net = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = decoder(images)
loss.backward()
# do above for many steps
dalle2 = DALLE2(
prior = diffusion_prior,
decoder = decoder
)
images = dalle2(
['cute puppy chasing after a squirrel'],
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image
```
Everything in this readme should run without error
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
@@ -32,22 +306,25 @@ $ dream 'sharing a sunset at the summit of mount everest with my dog'
Once built, images will be saved to the same directory the command is invoked
## Training (work in progress, will offer both in code and as command-line)
## Training wrapper (wip)
Offer training wrappers
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
Todo
## Todo
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [ ] make sure it works end to end to produce an output tensor, taking a single gradient step
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
## Citations
@@ -95,3 +372,5 @@ Todo
primaryClass = {cs.LG}
}
```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

View File

@@ -1 +1,2 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from x_clip import CLIP

View File

@@ -1,10 +1,19 @@
import tqdm
import math
from tqdm import tqdm
from inspect import isfunction
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange
from einops_exts import rearrange_many, repeat_many
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters import filter2d
from dalle2_pytorch.tokenizer import tokenizer
# use x-clip
@@ -16,7 +25,9 @@ def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
if exists(val):
return val
return d() if isfunction(d) else d
def eval_decorator(fn):
def inner(model, *args, **kwargs):
@@ -27,6 +38,11 @@ def eval_decorator(fn):
return out
return inner
def is_list_str(x):
if not isinstance(x, (list, tuple)):
return False
return all([type(el) == str for el in x])
# for controlling freezing of CLIP
def set_module_requires_grad_(module, requires_grad):
@@ -43,6 +59,11 @@ def freeze_model_and_make_eval_(model):
model.eval()
freeze_all_layers_(model)
# tensor helpers
def l2norm(t):
return F.normalize(t, dim = -1)
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
@@ -91,29 +112,83 @@ class RMSNorm(nn.Module):
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * self.gamma * self.scale
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
class ChanRMSNorm(RMSNorm):
def forward(self, x):
squared_sum = (x ** 2).sum(dim = 1, keepdim = True)
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.norm = RMSNorm(dim)
self.fn = fn
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
return self.fn(x, **kwargs) + x
# mlp
class MLP(nn.Module):
def __init__(
self,
dim_in,
dim_out,
*,
expansion_factor = 2.,
depth = 2,
norm = False,
):
super().__init__()
hidden_dim = int(expansion_factor * dim_out)
norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
layers = [nn.Sequential(
nn.Linear(dim_in, hidden_dim),
nn.SiLU(),
norm_fn()
)]
for _ in range(depth - 1):
layers.append(nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
norm_fn()
))
layers.append(nn.Linear(hidden_dim, dim_out))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x.float())
# feedforward
class SwiGLU(nn.Module):
""" used successfully in https://arxiv.org/abs/2204.0231 """
def forward(self, x):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(mult * dim)
return nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, inner_dim, bias = False),
nn.GELU(),
nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
# attention
class Attention(nn.Module):
def __init__(
self,
*,
dim,
*,
dim_head = 64,
heads = 8,
dropout = 0.,
@@ -121,6 +196,7 @@ class Attention(nn.Module):
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
self.causal = causal
@@ -128,17 +204,17 @@ class Attention(nn.Module):
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_qkv = nn.Linear(dim, inner_dim, bias = False)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, mask = None):
b, n, device = x.shape[:2], x.device
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
qkv = self.to_qkv(x).chunk(3, dim = -1)
q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
q = rearrange(q, 'b n (h d) -> b h n d')
q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
@@ -148,7 +224,7 @@ class Attention(nn.Module):
q = q * self.scale
sim = einsum('b h i d, b j d -> b h i j')
sim = einsum('b h i d, b j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
@@ -157,11 +233,13 @@ class Attention(nn.Module):
sim = sim.masked_fill(~mask, max_neg_value)
if self.causal:
causal_mask = torch.ones((n, n), dtype = torch.bool, device = device).triu(1)
i, j = sim.shape[-2:]
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -208,26 +286,26 @@ class DiffusionPriorNetwork(nn.Module):
def __init__(
self,
dim,
num_timesteps = 1000,
num_timesteps = None,
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(**kwargs)
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
def forward_with_cond_scale(
self,
x,
*,
*args,
cond_scale = 1.,
**kwargs
):
if cond_scale == 1:
return self.forward(x, **kwargs)
logits = self.forward(*args, **kwargs)
logits = self.forward(x, **kwargs)
null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -247,10 +325,18 @@ class DiffusionPriorNetwork(nn.Module):
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 4), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
@@ -268,7 +354,7 @@ class DiffusionPriorNetwork(nn.Module):
# classifier free guidance
cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
mask &= rearrange(cond_prob_mask, 'b -> b 1')
# attend
@@ -288,19 +374,20 @@ class DiffusionPrior(nn.Module):
*,
clip,
timesteps = 1000,
cond_prob_drop = 0.2,
cond_drop_prob = 0.2,
loss_type = 'l1',
predict_x0 = True
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.net = net
self.image_embed_dim = clip.dim_latent
self.channels = clip.image_channels
self.image_size = clip.image_size
self.cond_prob_drop = cond_prob_drop
self.cond_drop_prob = cond_drop_prob
self.predict_x0 = predict_x0
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -345,12 +432,13 @@ class DiffusionPrior(nn.Module):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return image_embed
return l2norm(image_embed)
def get_text_cond(self, text):
text_encodings = self.clip.text_transformer(text)
text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
text_embed = self.clip.to_text_latent(text_cls)
text_embed = l2norm(text_embed)
return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
def q_mean_variance(self, x_start, t):
@@ -389,7 +477,7 @@ class DiffusionPrior(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, text_cond = None, clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, text_cond = None, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
@@ -420,18 +508,18 @@ class DiffusionPrior(nn.Module):
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embeds']
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r')
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b d', d = image_embed_dim)
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return top_image_embeds
return rearrange(top_image_embeds, 'b 1 d -> b d')
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -442,14 +530,14 @@ class DiffusionPrior(nn.Module):
)
def p_losses(self, image_embed, t, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
noise = default(noise, lambda: torch.randn_like(image_embed))
image_embed_noisy = self.q_sample(x_start = image_embed, t = t, noise = noise)
x_recon = self.net(
image_embed_noisy,
t,
cond_prob_drop = self.cond_prob_drop,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)
@@ -472,7 +560,7 @@ class DiffusionPrior(nn.Module):
image_embed = self.get_image_embed(image)
text_cond = self.get_text_cond(text)
loss = self.p_losses(x, times, image_embed = image_embed, text_cond = text_cond, *args, **kwargs)
loss = self.p_losses(image_embed, times, text_cond = text_cond, *args, **kwargs)
return loss
# decoder
@@ -483,6 +571,17 @@ def Upsample(dim):
def Downsample(dim):
return nn.Conv2d(dim, dim, 4, 2, 1)
class Blur(nn.Module):
def __init__(self):
super().__init__()
filt = torch.Tensor([1, 2, 1])
self.register_buffer('filt', filt)
def forward(self, x):
filt = self.filt
filt = rearrange(filt, '... j -> ... 1 j') * rearrange(flit, '... i -> ... i 1')
return filter2d(x, filt, normalized = True)
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
@@ -510,16 +609,23 @@ class ConvNextBlock(nn.Module):
super().__init__()
need_projection = dim != dim_out
self.mlp = nn.Sequential(
nn.GELU(),
nn.Linear(cond_dim, dim)
) if exists(cond_dim) else None
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim,
context_dim = cond_dim
)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
inner_dim = int(dim_out * mult)
self.net = nn.Sequential(
RMSNorm(dim) if norm else nn.Identity(),
ChanRMSNorm(dim) if norm else nn.Identity(),
nn.Conv2d(dim, inner_dim, 3, padding = 1),
nn.GELU(),
nn.Conv2d(inner_dim, dim_out, 3, padding = 1)
@@ -530,28 +636,73 @@ class ConvNextBlock(nn.Module):
def forward(self, x, cond = None):
h = self.ds_conv(x)
if exists(self.mlp):
if exists(self.cross_attn):
assert exists(cond)
condition = self.mlp(cond)
h = h + rearrange(condition, 'b c -> b c 1 1')
h = self.cross_attn(h, context = cond) + h
h = self.net(h)
return h + self.res_conv(x)
class EinopsToAndFrom(nn.Module):
def __init__(self, from_einops, to_einops, fn):
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim = None,
dim_head = 64,
heads = 8,
dropout = 0.,
):
super().__init__()
self.from_einops = from_einops
self.to_einops = to_einops
self.fn = fn
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
def forward(self, x, **kwargs):
shape = x.shape
reconstitute_kwargs = dict(tuple(zip(self.from_einops.split(' '), shape)))
x = rearrange(x, f'{self.from_einops} -> {self.to_einops}')
x = self.fn(x, **kwargs)
x = rearrange(x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs)
return x
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Unet(nn.Module):
def __init__(
@@ -559,7 +710,8 @@ class Unet(nn.Module):
dim,
*,
image_embed_dim,
time_dim = None,
cond_dim = None,
num_image_tokens = 4,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
@@ -570,18 +722,28 @@ class Unet(nn.Module):
dims = [channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = default(time_dim, dim)
# time and image embeddings
cond_dim = default(cond_dim, dim)
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
)
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
cond_dim = time_dim + image_embed_dim
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
@@ -591,14 +753,15 @@ class Unet(nn.Module):
is_last = ind >= (num_resolutions - 1)
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim)))
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -618,16 +781,16 @@ class Unet(nn.Module):
def forward_with_cond_scale(
self,
x,
*,
*args,
cond_scale = 1.,
**kwargs
):
if cond_scale == 1:
return self.forward(x, **kwargs)
logits = self.forward(*args, **kwargs)
logits = self.forward(x, **kwargs)
null_logits = self.forward(x, cond_prob_drop = 1., **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -637,39 +800,42 @@ class Unet(nn.Module):
*,
image_embed,
text_encodings = None,
cond_prob_drop = 0.
cond_drop_prob = 0.
):
t = self.time_mlp(time)
batch_size, device = x.shape[0], x.device
time_tokens = self.time_mlp(time)
cond_prob_mask = prob_mask_like((batch_size,), cond_prob_drop, device = device)
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_embed = torch.where(
rearrange(cond_prob_mask, 'b -> b 1'),
image_embed,
rearrange(self.null_image_embed, 'd -> 1 d')
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
rearrange(cond_prob_mask, 'b -> b 1 1'),
image_tokens,
self.null_image_embed
)
cond = torch.cat((t, image_embed), dim = -1)
c = torch.cat((time_tokens, image_tokens), dim = -2) # c for condition
hiddens = []
for convnext, convnext2, downsample in self.downs:
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.attn(x)
x = self.mid_block2(x, t)
x = self.mid_block1(x, c)
x = self.mid_attn(x)
x = self.mid_block2(x, c)
for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
x = upsample(x)
return self.final_conv(x)
@@ -681,17 +847,18 @@ class Decoder(nn.Module):
*,
clip,
timesteps = 1000,
cond_prob_drop = 0.2,
cond_drop_prob = 0.2,
loss_type = 'l1'
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.net = net
self.channels = clip.image_channels
self.image_size = clip.image_size
self.cond_prob_drop = cond_prob_drop
self.cond_drop_prob = cond_drop_prob
betas = cosine_beta_schedule(timesteps)
@@ -733,7 +900,7 @@ class Decoder(nn.Module):
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
return image_embed
return l2norm(image_embed)
def q_mean_variance(self, x_start, t):
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
@@ -756,8 +923,8 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised: bool):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net(x, t, image_embed = image_embed))
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
if clip_denoised:
x_recon.clamp_(-1., 1.)
@@ -766,31 +933,31 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, image_embed):
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed)
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
return img
@torch.no_grad()
def sample(self, image_embed):
def sample(self, image_embed, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed)
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -800,7 +967,7 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, image_embed, t, noise = None):
def p_losses(self, x_start, t, *, image_embed, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
@@ -809,7 +976,7 @@ class Decoder(nn.Module):
x_noisy,
t,
image_embed = image_embed,
cond_prob_drop = self.cond_prob_drop
cond_drop_prob = self.cond_drop_prob
)
if self.loss_type == 'l1':
@@ -821,14 +988,14 @@ class Decoder(nn.Module):
return loss
def forward(self, image, *args, **kwargs):
def forward(self, image):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
loss = self.p_losses(x, times, image_embed = image_embed, *args, **kwargs)
loss = self.p_losses(image, times, image_embed = image_embed)
return loss
# main class
@@ -839,23 +1006,28 @@ class DALLE2(nn.Module):
*,
prior,
decoder,
tokenizer = None
prior_num_samples = 2
):
super().__init__()
assert isinstance(prior), DiffusionPrior
assert isinstance(decoder), Decoder
self.tokenizer = tokenizer
assert isinstance(prior, DiffusionPrior)
assert isinstance(decoder, Decoder)
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
@torch.no_grad()
@eval_decorator
def forward(
self,
*,
text
text,
cond_scale = 1.
):
if isinstance(text, str):
assert exists(self.tokenizer), 'tokenizer must be passed in if you were to pass in the text as a string'
text = self.tokenizer.encode(text)
device = next(self.parameters()).device
image_embed = prior.sample(text, num_samples_per_batch = 2)
images = decoder.sample(image_embed)
if isinstance(text, str) or is_list_str(text):
text = [text] if not isinstance(text, (list, tuple)) else text
text = tokenizer.tokenize(text).to(device)
image_embed = self.prior.sample(text, num_samples_per_batch = self.prior_num_samples)
images = self.decoder.sample(image_embed, cond_scale = cond_scale)
return images

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.3',
version = '0.0.12',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -24,7 +24,8 @@ setup(
install_requires=[
'click',
'einops>=0.4',
'einops-exts',
'einops-exts>=0.0.3',
'kornia>=0.5.4',
'pillow',
'torch>=1.10',
'torchvision',