Compare commits

...

27 Commits

Author SHA1 Message Date
Phil Wang
5a731cc936 make kernel size and sigma for gaussian blur for cascading DDPM overridable at forward. also make sure unets are wrapped in a modulelist so that at sample time, blurring does not happen 2022-04-18 12:00:47 -07:00
Phil Wang
6cddefad26 readme 2022-04-18 11:52:25 -07:00
Phil Wang
0332eaa6ff complete first pass at full cascading DDPM setup in Decoder, flexible enough to support one unet for testing 2022-04-18 11:44:56 -07:00
Phil Wang
1cce4225eb 0.0.18 2022-04-17 07:29:34 -07:00
Phil Wang
5ab0700bab Merge pull request #14 from kashif/loss-schedule
added huber loss and other schedulers
2022-04-17 07:29:10 -07:00
Kashif Rasul
b0f2fbaa95 schedule to Prior 2022-04-17 15:21:47 +02:00
Kashif Rasul
51361c2d15 added beta_schedule argument 2022-04-17 15:19:33 +02:00
Kashif Rasul
42d6e47387 added huber loss and other schedulers 2022-04-17 15:14:05 +02:00
Phil Wang
1e939153fb link to AssemblyAI explanation 2022-04-15 12:58:57 -07:00
Phil Wang
1abeb8918e personal project management for next week 2022-04-15 08:04:01 -07:00
Phil Wang
b423855483 commit to jax version 2022-04-15 07:16:25 -07:00
Phil Wang
c400d8758c prepare for cascading diffusion in unet, save the full progressive upsampling architecture to be built next week 2022-04-15 07:03:28 -07:00
Phil Wang
bece206699 fix bug thanks to @jihoonerd 2022-04-15 06:44:40 -07:00
Phil Wang
5b4ee09625 ideation 2022-04-14 13:48:01 -07:00
Phil Wang
6e27f617f1 use t5 relative positional bias in prior network causal transformer, since it makes more sense than rotary embeddings 2022-04-14 12:01:09 -07:00
Phil Wang
9f55c24db6 allow for decoder conditioning with the text encodings from CLIP, if it is passed in. use lazy linear to avoid researchers having to worry about text encoding dimensions, but remove later if it does not work well 2022-04-14 11:46:45 -07:00
Phil Wang
69e822b7f8 "project management" 2022-04-14 10:20:37 -07:00
Phil Wang
23c401a5d5 use the eval decorator 2022-04-14 10:13:43 -07:00
Phil Wang
68e9883f59 use cross attention for conditioning unet based on image embedding tokens (which opens up the door on conditioning on text encodings as well) 2022-04-14 10:10:04 -07:00
Phil Wang
95b018374a start using swish glu everywhere, given success of PaLM 2022-04-14 09:34:32 -07:00
Phil Wang
8b5c2385b0 better naming 2022-04-14 09:24:31 -07:00
Phil Wang
f2c52d8239 fix bug with classifier free guidance for prior network, even though it seems it may not be used 2022-04-14 09:21:51 -07:00
Phil Wang
97e951221b bring in blur, as it will be used somewhere in the cascading DDPM in the decoder eventually, once i figure it out 2022-04-14 09:16:09 -07:00
Phil Wang
e1b0c140f1 cleanup readme 2022-04-14 08:51:22 -07:00
Phil Wang
5989569a44 link to OpenCLIP effort 2022-04-14 08:31:15 -07:00
Phil Wang
82464d7bd3 per-fect 2022-04-14 08:30:07 -07:00
Phil Wang
7fb3f695d5 offer continuously parameterized time embedding for diffusion prior network, remove a hyperparameter that may trip up people, if not set correctly 2022-04-14 08:28:11 -07:00
3 changed files with 606 additions and 146 deletions

README.md

@@ -2,7 +2,9 @@
## DALL-E 2 - Pytorch (wip)
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch. <a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a>
Implementation of <a href="https://openai.com/dall-e-2/">DALL-E 2</a>, OpenAI's updated text-to-image synthesis neural network, in Pytorch.
<a href="https://youtu.be/RJwPN4qNi_Y?t=555">Yannic Kilcher summary</a> | <a href="https://www.youtube.com/watch?v=F1X4fHzF4mQ">AssemblyAI explainer</a>
The main novelty seems to be an extra layer of indirection with the prior network (whether it is an autoregressive transformer or a diffusion network), which predicts an image embedding based on the text embedding from CLIP. Specifically, this repository will only build out the diffusion prior network, as it is the best performing variant (but which incidentally involves a causal transformer as the denoising network 😂)
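In other words, the two stages compose roughly like so (a minimal sketch only; `diffusion_prior` and `decoder` stand for trained instances of the `DiffusionPrior` and `Decoder` classes used in the examples below):
```python
import torch

# illustrative sketch - assumes a trained `diffusion_prior` (DiffusionPrior) and `decoder` (Decoder)
text = torch.randint(0, 49408, (1, 256)).cuda()   # mock tokenized text

image_embed = diffusion_prior.sample(text)        # text tokens -> predicted CLIP image embedding
images = decoder.sample(image_embed)              # image embedding -> generated image
```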
@@ -12,9 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
Do let me know if anyone is interested in a Jax version https://github.com/lucidrains/DALLE2-pytorch/discussions/8
For all of you emailing me (there are a lot of you), the best way to contribute is through pull requests. Everything is open sourced after all. All my thoughts are public. This is your moment to participate.
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
## Install
@@ -22,19 +22,11 @@ For all of you emailing me (there is a lot), the best way to contribute is throu
$ pip install dalle2-pytorch
```
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked from
## Training (for deep learning practitioners)
## Usage
Training DALL-E 2 is a 3-step process, with the training of CLIP being the most important
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already underway.
To train CLIP, you can either use <a href="https://github.com/lucidrains/x-clip">x-clip</a> package, or join the LAION discord, where a lot of replication efforts are already <a href="https://github.com/mlfoundations/open_clip">underway</a>.
This repository will demonstrate integration with `x-clip` for starters
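As a rough sketch of that first step (the contrastive forward with `return_loss = True` is assumed here from the `x-clip` API):
```python
import torch
from dalle2_pytorch import CLIP

clip = CLIP(
    dim_text = 512,
    dim_image = 512,
    dim_latent = 512,
    num_text_tokens = 49408,
    text_enc_depth = 1,
    text_seq_len = 256,
    text_heads = 8,
    visual_enc_depth = 1,
    visual_image_size = 256,
    visual_patch_size = 32,
    visual_heads = 8
).cuda()

# mock data
text = torch.randint(0, 49408, (4, 256)).cuda()
images = torch.randn(4, 3, 256, 256).cuda()

# contrastive training step (return_loss kwarg assumed from x-clip)
loss = clip(text, images, return_loss = True)
loss.backward()
```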
@@ -109,7 +101,7 @@ clip = CLIP(
unet = Unet(
dim = 128,
image_embed_dim = 512,
time_dim = 128,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
@@ -117,7 +109,7 @@ unet = Unet(
# decoder, which contains the unet and clip
decoder = Decoder(
net = unet,
unet = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
@@ -162,7 +154,6 @@ clip = CLIP(
prior_network = DiffusionPriorNetwork(
dim = 512,
num_timesteps = 100,
depth = 6,
dim_head = 64,
heads = 8
@@ -191,6 +182,81 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
This can easily be used within this framework like so
```python
import torch
from dalle2_pytorch import Unet, Decoder, CLIP
# trained clip from step 1
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
).cuda()
# 2 unets for the decoder (a la cascading DDPM)
unet1 = Unet(
dim = 16,
image_embed_dim = 512,
channels = 3,
dim_mults = (1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off)
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unet and clip
decoder = Decoder(
clip = clip,
unet = (unet1, unet2), # insert both unets in order of lowest to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
# mock images (get a lot of this)
images = torch.randn(4, 3, 512, 512).cuda()
# feed images into decoder, specifying which unet you want to train
# each unet can be trained separately, which is one of the benefits of the cascading DDPM scheme
loss = decoder(images, unet_number = 1)
loss.backward()
loss = decoder(images, unet_number = 2)
loss.backward()
# do the above for many steps for both unets
# then it will learn to generate images based on the CLIP image embeddings
# chaining the unets from lowest resolution to highest resolution (thus cascading)
mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
Finally, to generate the DALL-E 2 images from text, insert the trained `DiffusionPrior` as well as the `Decoder` (which between them contain `CLIP`, a unet, and a causal transformer)
```python
@@ -251,7 +317,6 @@ loss.backward()
prior_network = DiffusionPriorNetwork(
dim = 512,
num_timesteps = 100,
depth = 6,
dim_head = 64,
heads = 8
@@ -271,23 +336,34 @@ loss.backward()
# decoder (with unet)
unet = Unet(
unet1 = Unet(
dim = 128,
image_embed_dim = 512,
time_dim = 128,
cond_dim = 128,
channels = 3,
dim_mults=(1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
lowres_cond = True
).cuda()
decoder = Decoder(
net = unet,
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
).cuda()
loss = decoder(images)
loss.backward()
for unet_number in (1, 2):
loss = decoder(images, unet_number = unet_number) # this can optionally be decoder(images, text) if you wish to condition on the text encodings as well, though it was hinted in the paper it didn't do much
loss.backward()
# do above for many steps
@@ -301,13 +377,27 @@ images = dalle2(
cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
)
# save your image
# save your image (in this example, of size 256x256)
```
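To actually write the returned tensor to disk, something like the following should work (a sketch; it assumes the decoder output lies roughly in `[-1, 1]`, as the clamping in `p_sample` suggests):
```python
from torchvision.utils import save_image

# unnormalize from [-1, 1] to [0, 1] before saving (assumption, see note above)
save_image((images[0] + 1) / 2, './dalle2_sample.png')
```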
Everything in this readme should run without error
You can also train the decoder on images larger (say 512x512) than the size at which CLIP was trained (256x256). The images will be resized to the CLIP image resolution for the image embeddings
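Concretely, that resize is just an interpolation down to the CLIP resolution before the image embedding is computed; a minimal sketch mirroring the `resize_image_to` helper added in this diff:
```python
import torch
import torch.nn.functional as F

def resize_image_to(t, image_size, mode = 'bilinear'):
    # no-op if the image is already at the target resolution
    if t.shape[-2:] == (image_size, image_size):
        return t
    return F.interpolate(t, size = (image_size, image_size), mode = mode)

images = torch.randn(4, 3, 512, 512)        # decoder training images at 512x512
clip_input = resize_image_to(images, 256)   # downsized to CLIP's 256x256 for the image embedding
```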
For the layperson, no worries, training will all be automated into a CLI tool, at least for small scale training.
## CLI Usage (work in progress)
```bash
$ dream 'sharing a sunset at the summit of mount everest with my dog'
```
Once built, images will be saved to the same directory the command is invoked from
## Training wrapper (wip)
Offer training wrappers
## Training CLI (wip)
<a href="https://github.com/lucidrains/stylegan2-pytorch">template</a>
@@ -317,12 +407,17 @@ For the layperson, no worries, training will all be automated into a CLI tool, a
- [x] finish off gaussian diffusion class for latent embedding - allow for prediction of epsilon
- [x] add what was proposed in the paper, where DDPM objective for image latent embedding predicts x0 directly (reread vq-diffusion paper and get caught up on that line of work)
- [x] make sure it works end to end to produce an output tensor, taking a single gradient step
- [ ] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
- [ ] look into Jonathan Ho's cascading DDPM for the decoder, as that seems to be what they are using. get caught up on DDPM literature
- [ ] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest
- [ ] make unet more configurable
- [ ] figure out some factory methods to make cascading unet instantiations less error-prone
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention
- [ ] figure out the big idea behind latent diffusion and what can be ported over
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
## Citations

dalle2_pytorch/dalle2_pytorch.py

@@ -1,15 +1,20 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters.gaussian import GaussianBlur2d
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
# use x-clip
@@ -26,6 +31,9 @@ def default(val, d):
return val
return d() if isfunction(d) else d
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def eval_decorator(fn):
def inner(model, *args, **kwargs):
was_training = model.training
@@ -61,6 +69,15 @@ def freeze_model_and_make_eval_(model):
def l2norm(t):
return F.normalize(t, dim = -1)
def resize_image_to(t, image_size, mode = 'bilinear'): # take a look at https://github.com/assafshocher/ResizeRight
shape = cast_tuple(image_size, 2)
orig_image_size = t.shape[-2:]
if orig_image_size == shape:
return t
return F.interpolate(t, size = shape, mode = mode)
# classifier free guidance functions
def prob_mask_like(shape, prob, device):
@@ -95,6 +112,29 @@ def cosine_beta_schedule(timesteps, s = 0.008):
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps)
def quadratic_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start**2, beta_end**2, timesteps) ** 2
def sigmoid_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
betas = torch.linspace(-6, 6, timesteps)
return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
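# note (illustrative, not part of the module): the schedule is picked by name through the new
# `beta_schedule` keyword on DiffusionPrior / Decoder below, alongside the new 'huber' loss, e.g.
#
#   prior = DiffusionPrior(net = prior_network, clip = clip, beta_schedule = 'sigmoid', loss_type = 'huber')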
# diffusion prior
class RMSNorm(nn.Module):
@@ -115,25 +155,110 @@ class ChanRMSNorm(RMSNorm):
inv_norm = torch.rsqrt(squared_sum + self.eps)
return x * inv_norm * rearrange(self.gamma, 'c -> 1 c 1 1') * self.scale
class PreNormResidual(nn.Module):
def __init__(self, dim, fn):
class Residual(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.norm = RMSNorm(dim)
def forward(self, x, **kwargs):
return self.fn(self.norm(x), **kwargs) + x
return self.fn(x, **kwargs) + x
# mlp
class MLP(nn.Module):
def __init__(
self,
dim_in,
dim_out,
*,
expansion_factor = 2.,
depth = 2,
norm = False,
):
super().__init__()
hidden_dim = int(expansion_factor * dim_out)
norm_fn = lambda: nn.LayerNorm(hidden_dim) if norm else nn.Identity()
layers = [nn.Sequential(
nn.Linear(dim_in, hidden_dim),
nn.SiLU(),
norm_fn()
)]
for _ in range(depth - 1):
layers.append(nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.SiLU(),
norm_fn()
))
layers.append(nn.Linear(hidden_dim, dim_out))
self.net = nn.Sequential(*layers)
def forward(self, x):
return self.net(x.float())
# relative positional bias for causal transformer
class RelPosBias(nn.Module):
def __init__(
self,
heads = 8,
num_buckets = 32,
max_distance = 128,
):
super().__init__()
self.num_buckets = num_buckets
self.max_distance = max_distance
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(
relative_position,
num_buckets = 32,
max_distance = 128
):
n = -relative_position
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = max_exact + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long()
val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
return torch.where(is_small, n, val_if_large)
def forward(self, i, j, *, device):
q_pos = torch.arange(i, dtype = torch.long, device = device)
k_pos = torch.arange(j, dtype = torch.long, device = device)
rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
values = self.relative_attention_bias(rp_bucket)
return rearrange(values, 'i j h -> h i j')
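# the returned bias has shape (heads, i, j); Attention.forward below adds it directly to the
# similarity logits, broadcasting over the batch dimension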
# feedforward
class SwiGLU(nn.Module):
""" used successfully in https://arxiv.org/abs/2204.0231 """
def forward(self, x):
x, gate = x.chunk(2, dim = -1)
return x * F.silu(gate)
def FeedForward(dim, mult = 4, dropout = 0., post_activation_norm = False):
""" post-activation norm https://arxiv.org/abs/2110.09456 """
def FeedForward(dim, mult = 4, dropout = 0.):
inner_dim = int(mult * dim)
return nn.Sequential(
RMSNorm(dim),
nn.Linear(dim, inner_dim, bias = False),
nn.GELU(),
nn.Linear(dim, inner_dim * 2, bias = False),
SwiGLU(),
RMSNorm(inner_dim) if post_activation_norm else nn.Identity(),
nn.Dropout(dropout),
nn.Linear(inner_dim, dim, bias = False)
)
# attention
class Attention(nn.Module):
def __init__(
self,
@@ -158,7 +283,7 @@ class Attention(nn.Module):
self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, mask = None):
def forward(self, x, mask = None, attn_bias = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
@@ -175,6 +300,14 @@ class Attention(nn.Module):
q = q * self.scale
sim = einsum('b h i d, b j d -> b h i j', q, k)
# relative positional encoding (T5 style)
if exists(attn_bias):
sim = sim + attn_bias
# masking
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
@@ -187,8 +320,13 @@ class Attention(nn.Module):
causal_mask = torch.ones((i, j), dtype = torch.bool, device = device).triu(j - i + 1)
sim = sim.masked_fill(causal_mask, max_neg_value)
# attention
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
attn = self.dropout(attn)
# aggregate values
out = einsum('b h i j, b j d -> b h i d', attn, v)
@@ -209,7 +347,7 @@ class CausalTransformer(nn.Module):
ff_dropout = 0.
):
super().__init__()
# todo - bring in rotary embeddings or alibi
self.rel_pos_bias = RelPosBias(heads = heads)
self.layers = nn.ModuleList([])
for _ in range(depth):
@@ -225,8 +363,12 @@ class CausalTransformer(nn.Module):
x,
mask = None # we will need a mask here, due to variable length of the text encodings - also offer dalle1 strategy with padding token embeddings
):
n, device = x.shape[1], x.device
attn_bias = self.rel_pos_bias(n, n + 1, device = device)
for attn, ff in self.layers:
x = attn(x, mask = mask) + x
x = attn(x, mask = mask, attn_bias = attn_bias) + x
x = ff(x) + x
return self.norm(x)
@@ -235,27 +377,26 @@ class DiffusionPriorNetwork(nn.Module):
def __init__(
self,
dim,
num_timesteps = 1000,
num_timesteps = None,
**kwargs
):
super().__init__()
self.time_embeddings = nn.Embedding(num_timesteps, dim) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.time_embeddings = nn.Embedding(num_timesteps, dim) if exists(num_timesteps) else nn.Sequential(Rearrange('b -> b 1'), MLP(1, dim)) # also offer a continuous version of timestep embeddings, with a 2 layer MLP
self.learned_query = nn.Parameter(torch.randn(dim))
self.causal_transformer = CausalTransformer(dim = dim, **kwargs)
def forward_with_cond_scale(
self,
x,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
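# classifier free guidance: cond_scale = 1 returns the conditional prediction unchanged, while e.g.
# cond_scale = 2 gives null + 2 * (cond - null), pushing the prediction further from the
# unconditional baseline (Unet.forward_with_cond_scale below follows the same pattern)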
def forward(
@@ -275,8 +416,15 @@ class DiffusionPriorNetwork(nn.Module):
text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
if exists(mask):
mask = F.pad(mask, (0, 3), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
not_all_masked_out = mask.any(dim = -1)
mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
if exists(mask):
mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.time_embeddings(diffusion_timesteps)
time_embed = rearrange(time_embed, 'b d -> b 1 d')
@@ -316,10 +464,11 @@ class DiffusionPrior(nn.Module):
net,
*,
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1',
predict_x0 = True
timesteps=1000,
cond_drop_prob=0.2,
loss_type="l1",
predict_x0=True,
beta_schedule="cosine",
):
super().__init__()
assert isinstance(clip, CLIP)
@@ -335,11 +484,22 @@ class DiffusionPrior(nn.Module):
self.predict_x0 = predict_x0
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
betas = cosine_beta_schedule(timesteps)
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
@@ -439,31 +599,6 @@ class DiffusionPrior(nn.Module):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.no_grad()
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -490,11 +625,39 @@ class DiffusionPrior(nn.Module):
loss = F.l1_loss(to_predict, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(to_predict, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(to_predict, x_recon)
else:
raise NotImplementedError()
return loss
@torch.no_grad()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):
# in the paper, what they did was
# sample 2 image embeddings, choose the top 1 similarity, as judged by CLIP
text = repeat(text, 'b ... -> (b r) ...', r = num_samples_per_batch)
batch_size = text.shape[0]
image_embed_dim = self.image_embed_dim
text_cond = self.get_text_cond(text)
image_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond = text_cond)
text_embeds = text_cond['text_embed']
text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))
top_sim_indices = text_image_sims.topk(k = 1).indices
top_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d = image_embed_dim)
top_image_embeds = image_embeds.gather(1, top_sim_indices)
return rearrange(top_image_embeds, 'b 1 d -> b d')
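# net effect: for each text prompt, `num_samples_per_batch` candidate image embeddings are sampled
# and the one with the highest CLIP text-image similarity is kept, giving a (batch, image_embed_dim) tensor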
def forward(self, text, image, *args, **kwargs):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
@@ -541,10 +704,17 @@ class ConvNextBlock(nn.Module):
super().__init__()
need_projection = dim != dim_out
self.mlp = nn.Sequential(
nn.GELU(),
nn.Linear(cond_dim, dim)
) if exists(cond_dim) else None
self.cross_attn = None
if exists(cond_dim):
self.cross_attn = EinopsToAndFrom(
'b c h w',
'b (h w) c',
CrossAttention(
dim = dim,
context_dim = cond_dim
)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
@@ -561,70 +731,164 @@ class ConvNextBlock(nn.Module):
def forward(self, x, cond = None):
h = self.ds_conv(x)
if exists(self.mlp):
if exists(self.cross_attn):
assert exists(cond)
condition = self.mlp(cond)
h = h + rearrange(condition, 'b c -> b c 1 1')
h = self.cross_attn(h, context = cond) + h
h = self.net(h)
return h + self.res_conv(x)
class CrossAttention(nn.Module):
def __init__(
self,
dim,
*,
context_dim = None,
dim_head = 64,
heads = 8,
dropout = 0.,
):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
inner_dim = dim_head * heads
context_dim = default(context_dim, dim)
self.norm = RMSNorm(dim)
self.norm_context = RMSNorm(context_dim)
self.dropout = nn.Dropout(dropout)
self.null_kv = nn.Parameter(torch.randn(2, dim_head))
self.to_q = nn.Linear(dim, inner_dim, bias = False)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim, bias = False)
def forward(self, x, context, mask = None):
b, n, device = *x.shape[:2], x.device
x = self.norm(x)
context = self.norm_context(context)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
# add null key / value for classifier free guidance in prior net
nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
k = torch.cat((nk, k), dim = -2)
v = torch.cat((nv, v), dim = -2)
q = q * self.scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
max_neg_value = -torch.finfo(sim.dtype).max
if exists(mask):
mask = F.pad(mask, (1, 0), value = True)
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, max_neg_value)
sim = sim - sim.amax(dim = -1, keepdim = True)
attn = sim.softmax(dim = -1)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
class Unet(nn.Module):
def __init__(
self,
dim,
*,
image_embed_dim,
time_dim = None,
cond_dim = None,
num_image_tokens = 4,
out_dim = None,
dim_mults=(1, 2, 4, 8),
channels = 3,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
blur_kernel_size = 3,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
):
super().__init__()
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
self.channels = channels
dims = [channels, *map(lambda m: dim * m, dim_mults)]
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
time_dim = default(time_dim, dim)
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
self.time_mlp = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, dim)
nn.Linear(dim * 4, cond_dim),
Rearrange('b d -> b 1 d')
)
self.null_image_embed = nn.Parameter(torch.randn(image_embed_dim))
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
) if image_embed_dim != cond_dim else nn.Identity()
cond_dim = time_dim + image_embed_dim
self.text_to_cond = nn.LazyLinear(cond_dim)
# for classifier free guidance
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# layers
self.downs = nn.ModuleList([])
self.ups = nn.ModuleList([])
num_resolutions = len(in_out)
for ind, (dim_in, dim_out) in enumerate(in_out):
is_first = ind == 0
is_last = ind >= (num_resolutions - 1)
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, cond_dim = cond_dim, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = cond_dim),
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', PreNormResidual(mid_dim, Attention(mid_dim)))
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 1)
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = cond_dim),
Upsample(dim_in) if not is_last else nn.Identity()
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
out_dim = default(out_dim, channels)
@@ -635,17 +899,16 @@ class Unet(nn.Module):
def forward_with_cond_scale(
self,
x,
*args,
cond_scale = 1.,
**kwargs
):
logits = self.forward(x, *args, **kwargs)
logits = self.forward(*args, **kwargs)
if cond_scale == 1:
return logits
null_logits = self.forward(x, *args, cond_drop_prob = 1., **kwargs)
null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
return null_logits + (logits - null_logits) * cond_scale
def forward(
@@ -654,41 +917,88 @@ class Unet(nn.Module):
time,
*,
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.
cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
batch_size, device = x.shape[0], x.device
t = self.time_mlp(time)
# add low resolution conditioning, if present
assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
# time conditioning
time_tokens = self.time_mlp(time)
# conditional dropout
cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1 1')
# mask out image embedding depending on condition dropout
# for classifier free guidance
image_embed = torch.where(
rearrange(cond_prob_mask, 'b -> b 1'),
image_embed,
rearrange(self.null_image_embed, 'd -> 1 d')
image_tokens = self.image_to_cond(image_embed)
image_tokens = torch.where(
cond_prob_mask,
image_tokens,
self.null_image_embed
)
t = torch.cat((t, image_embed), dim = -1)
# take care of text encodings (optional)
if exists(text_encodings):
text_tokens = self.text_to_cond(text_encodings)
text_tokens = torch.where(
cond_prob_mask,
text_tokens,
self.null_text_embed
)
# main conditioning tokens (c)
c = torch.cat((time_tokens, image_tokens), dim = -2)
# text and image conditioning tokens (mid_c)
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
# go through the layers of the unet, down and up
hiddens = []
for convnext, convnext2, downsample in self.downs:
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)
x = self.mid_block1(x, mid_c)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_block2(x, mid_c)
for convnext, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, t)
x = convnext2(x, t)
x = convnext(x, c)
x = convnext2(x, c)
x = upsample(x)
return self.final_conv(x)
@@ -696,28 +1006,50 @@ class Unet(nn.Module):
class Decoder(nn.Module):
def __init__(
self,
net,
unet,
*,
clip,
timesteps = 1000,
cond_drop_prob = 0.2,
loss_type = 'l1'
loss_type = 'l1',
beta_schedule = 'cosine',
image_sizes = None # for cascading ddpm, image size at each stage
):
super().__init__()
assert isinstance(clip, CLIP)
freeze_model_and_make_eval_(clip)
self.clip = clip
self.net = net
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.image_size = clip.image_size
self.unets = nn.ModuleList(unet)
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
assert len(self.unets) == len(image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {image_sizes}'
self.image_sizes = image_sizes
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
assert lowres_conditions == (False, *((True,) * (len(self.unets) - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
self.cond_drop_prob = cond_drop_prob
betas = cosine_beta_schedule(timesteps)
if beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
elif beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "quadratic":
betas = quadratic_beta_schedule(timesteps)
elif beta_schedule == "jsd":
betas = 1.0 / torch.linspace(timesteps, 1, timesteps)
elif beta_schedule == "sigmoid":
betas = sigmoid_beta_schedule(timesteps)
else:
raise NotImplementedError()
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (0, 1), value = 1.)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
@@ -749,7 +1081,12 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:]
def get_image_embed(self, image):
image = resize_image_to(image, self.clip_image_size)
image_encoding = self.clip.visual_transformer(image)
image_cls = image_encoding[:, 0]
image_embed = self.clip.to_visual_latent(image_cls)
@@ -776,8 +1113,9 @@ class Decoder(nn.Module):
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, image_embed, clip_denoised = True, cond_scale = 1.):
x_recon = self.predict_start_from_noise(x, t = t, noise = self.net.forward_with_cond_scale(x, t, image_embed = image_embed, cond_scale = cond_scale))
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, cond_scale = 1.):
pred_noise = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
x_recon = self.predict_start_from_noise(x, t = t, noise = pred_noise)
if clip_denoised:
x_recon.clamp_(-1., 1.)
@@ -786,32 +1124,25 @@ class Decoder(nn.Module):
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, image_embed, cond_scale = 1., clip_denoised = True, repeat_noise = False):
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, clip_denoised = True, repeat_noise = False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, image_embed = image_embed, cond_scale = cond_scale, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
@torch.no_grad()
def p_sample_loop(self, shape, image_embed, cond_scale = 1):
def p_sample_loop(self, unet, shape, image_embed, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
img = self.p_sample(img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, cond_scale = cond_scale)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(unet, img, torch.full((b,), i, device = device, dtype = torch.long), image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
return img
@torch.no_grad()
def sample(self, image_embed, cond_scale = 1.):
batch_size = image_embed.shape[0]
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size), image_embed = image_embed, cond_scale = cond_scale)
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
@@ -820,15 +1151,17 @@ class Decoder(nn.Module):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
def p_losses(self, x_start, t, *, image_embed, noise = None):
def p_losses(self, unet, x_start, t, *, image_embed, lowres_cond_img = None, text_encodings = None, noise = None):
noise = default(noise, lambda: torch.randn_like(x_start))
x_noisy = self.q_sample(x_start = x_start, t = t, noise = noise)
x_recon = self.net(
x_recon = unet(
x_noisy,
t,
image_embed = image_embed,
text_encodings = text_encodings,
lowres_cond_img = lowres_cond_img,
cond_drop_prob = self.cond_drop_prob
)
@@ -836,20 +1169,50 @@ class Decoder(nn.Module):
loss = F.l1_loss(noise, x_recon)
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
elif self.loss_type == "huber":
loss = F.smooth_l1_loss(noise, x_recon)
else:
raise NotImplementedError()
return loss
def forward(self, image):
b, device, img_size, = image.shape[0], image.device, self.image_size
check_shape(image, 'b c h w', h = img_size, w = img_size, c = self.channels)
@torch.no_grad()
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
batch_size = image_embed.shape[0]
channels = self.channels
text_encodings = self.get_text_encodings(text) if exists(text) else None
img = None
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
shape = (batch_size, channels, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
return img
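# each stage's output is fed to the next unet as `lowres_cond_img`, so the final image comes out
# at the last resolution in `image_sizes`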
def forward(self, image, text = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
index = unet_number - 1
unet = self.unets[index]
target_image_size = self.image_sizes[index]
b, c, h, w, device, = *image.shape, image.device
check_shape(image, 'b c h w', c = self.channels)
assert h >= target_image_size and w >= target_image_size
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
loss = self.p_losses(image, times, image_embed = image_embed)
return loss
image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
lowres_cond_img = image if index > 0 else None
ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)
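# note: during training, the ground truth image itself stands in for the low resolution conditioning
# of unets past the first (it is blurred and resized to the unet's resolution inside Unet.forward),
# which is what lets each unet be trained independently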
# main class
@@ -864,11 +1227,12 @@ class DALLE2(nn.Module):
super().__init__()
assert isinstance(prior, DiffusionPrior)
assert isinstance(decoder, Decoder)
self.prior = prior.eval()
self.decoder = decoder.eval()
self.prior = prior
self.decoder = decoder
self.prior_num_samples = prior_num_samples
@torch.no_grad()
@eval_decorator
def forward(
self,
text,

setup.py

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.7',
version = '0.0.21',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',
@@ -25,6 +25,7 @@ setup(
'click',
'einops>=0.4',
'einops-exts>=0.0.3',
'kornia>=0.5.4',
'pillow',
'torch>=1.10',
'torchvision',