mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-12 11:34:29 +01:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27a33e1b20 | ||
|
|
6f941a219a | ||
|
|
ddde8ca1bf | ||
|
|
c26b77ad20 | ||
|
|
c5b4aab8e5 | ||
|
|
a35c309b5f | ||
|
|
55bdcb98b9 | ||
|
|
82328f16cd | ||
|
|
6fee4fce6e | ||
|
|
a54e309269 | ||
|
|
c6bfd7fdc8 | ||
|
|
960a79857b | ||
|
|
7214df472d | ||
|
|
00ae50999b | ||
|
|
6cddefad26 |
47
README.md
47
README.md
@@ -14,7 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
|
||||
|
||||
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
|
||||
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
|
||||
There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
|
||||
|
||||
## Install
|
||||
|
||||
@@ -109,7 +109,7 @@ unet = Unet(
|
||||
# decoder, which contains the unet and clip
|
||||
|
||||
decoder = Decoder(
|
||||
net = unet,
|
||||
unet = unet,
|
||||
clip = clip,
|
||||
timesteps = 100,
|
||||
cond_drop_prob = 0.2
|
||||
@@ -182,9 +182,9 @@ loss.backward()
|
||||
# now the diffusion prior can generate image embeddings from the text embeddings
|
||||
```
|
||||
|
||||
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, from which DALL-E2 is based).
|
||||
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
|
||||
|
||||
This can easily be used within the framework offered in this repository as so
|
||||
This can easily be used within this framework as so
|
||||
|
||||
```python
|
||||
import torch
|
||||
@@ -197,10 +197,10 @@ clip = CLIP(
|
||||
dim_image = 512,
|
||||
dim_latent = 512,
|
||||
num_text_tokens = 49408,
|
||||
text_enc_depth = 1,
|
||||
text_enc_depth = 6,
|
||||
text_seq_len = 256,
|
||||
text_heads = 8,
|
||||
visual_enc_depth = 1,
|
||||
visual_enc_depth = 6,
|
||||
visual_image_size = 256,
|
||||
visual_patch_size = 32,
|
||||
visual_heads = 8
|
||||
@@ -209,28 +209,28 @@ clip = CLIP(
|
||||
# 2 unets for the decoder (a la cascading DDPM)
|
||||
|
||||
unet1 = Unet(
|
||||
dim = 16,
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8)
|
||||
).cuda()
|
||||
|
||||
unet2 = Unet(
|
||||
dim = 16,
|
||||
dim = 32,
|
||||
image_embed_dim = 512,
|
||||
lowres_cond = True, # subsequence unets must have this turned on (and first unet must have this turned off)
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
# decoder, which contains the unet and clip
|
||||
# decoder, which contains the unet(s) and clip
|
||||
|
||||
decoder = Decoder(
|
||||
clip = clip,
|
||||
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
|
||||
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second
|
||||
timesteps = 100,
|
||||
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
|
||||
timesteps = 1000,
|
||||
cond_drop_prob = 0.2
|
||||
).cuda()
|
||||
|
||||
@@ -257,7 +257,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
|
||||
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
|
||||
```
|
||||
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
|
||||
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
|
||||
|
||||
```python
|
||||
from dalle2_pytorch import DALLE2
|
||||
@@ -349,8 +349,7 @@ unet2 = Unet(
|
||||
image_embed_dim = 512,
|
||||
cond_dim = 128,
|
||||
channels = 3,
|
||||
dim_mults = (1, 2, 4, 8, 16),
|
||||
lowres_cond = True
|
||||
dim_mults = (1, 2, 4, 8, 16)
|
||||
).cuda()
|
||||
|
||||
decoder = Decoder(
|
||||
@@ -410,12 +409,12 @@ Offer training wrappers
|
||||
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
|
||||
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
|
||||
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
|
||||
- [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest
|
||||
- [ ] make unet more configurable
|
||||
- [ ] train on a toy task, offer in colab
|
||||
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
|
||||
- [x] add efficient attention in unet
|
||||
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
|
||||
- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
|
||||
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
|
||||
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
|
||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||
- [ ] train on a toy task, offer in colab
|
||||
|
||||
## Citations
|
||||
|
||||
@@ -464,4 +463,12 @@ Offer training wrappers
|
||||
}
|
||||
```
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Tu2022MaxViTMV,
|
||||
title = {MaxViT: Multi-Axis Vision Transformer},
|
||||
author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
|
||||
year = {2022}
|
||||
}
|
||||
```
|
||||
|
||||
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>
|
||||
|
||||
@@ -6,4 +6,4 @@ def main():
|
||||
@click.command()
|
||||
@click.argument('text')
|
||||
def dream(text):
|
||||
return image
|
||||
return 'not ready yet'
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
from inspect import isfunction
|
||||
from functools import partial
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -11,7 +13,7 @@ from einops.layers.torch import Rearrange
|
||||
from einops_exts import rearrange_many, repeat_many, check_shape
|
||||
from einops_exts.torch import EinopsToAndFrom
|
||||
|
||||
from kornia.filters.gaussian import GaussianBlur2d
|
||||
from kornia.filters import gaussian_blur2d
|
||||
|
||||
from dalle2_pytorch.tokenizer import tokenizer
|
||||
|
||||
@@ -104,8 +106,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, steps, steps)
|
||||
alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
x = torch.linspace(0, timesteps, steps)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return torch.clip(betas, 0, 0.999)
|
||||
@@ -797,6 +799,20 @@ class CrossAttention(nn.Module):
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
return self.to_out(out)
|
||||
|
||||
class GridAttention(nn.Module):
|
||||
def __init__(self, *args, window_size = 8, **kwargs):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.attn = Attention(*args, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
h, w = x.shape[-2:]
|
||||
wsz = self.window_size
|
||||
x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
|
||||
out = self.attn(x)
|
||||
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
|
||||
return out
|
||||
|
||||
class Unet(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -805,21 +821,33 @@ class Unet(nn.Module):
|
||||
image_embed_dim,
|
||||
cond_dim = None,
|
||||
num_image_tokens = 4,
|
||||
num_time_tokens = 2,
|
||||
out_dim = None,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
channels = 3,
|
||||
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
|
||||
lowres_cond_upsample_mode = 'bilinear',
|
||||
blur_sigma = 0.1,
|
||||
blur_kernel_size = 3,
|
||||
sparse_attn = False,
|
||||
sparse_attn_window = 8, # window size for sparse attention
|
||||
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
|
||||
cond_on_text_encodings = False,
|
||||
cond_on_image_embeds = False,
|
||||
):
|
||||
super().__init__()
|
||||
# save locals to take care of some hyperparameters for cascading DDPM
|
||||
|
||||
self._locals = locals()
|
||||
del self._locals['self']
|
||||
del self._locals['__class__']
|
||||
|
||||
# for eventual cascading diffusion
|
||||
|
||||
self.lowres_cond = lowres_cond
|
||||
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
|
||||
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
|
||||
self.lowres_blur_kernel_size = blur_kernel_size
|
||||
self.lowres_blur_sigma = blur_sigma
|
||||
|
||||
# determine dimensions
|
||||
|
||||
@@ -838,8 +866,8 @@ class Unet(nn.Module):
|
||||
SinusoidalPosEmb(dim),
|
||||
nn.Linear(dim, dim * 4),
|
||||
nn.GELU(),
|
||||
nn.Linear(dim * 4, cond_dim),
|
||||
Rearrange('b d -> b 1 d')
|
||||
nn.Linear(dim * 4, cond_dim * num_time_tokens),
|
||||
Rearrange('b (r d) -> b r d', r = num_time_tokens)
|
||||
)
|
||||
|
||||
self.image_to_cond = nn.Sequential(
|
||||
@@ -849,6 +877,12 @@ class Unet(nn.Module):
|
||||
|
||||
self.text_to_cond = nn.LazyLinear(cond_dim)
|
||||
|
||||
# finer control over whether to condition on image embeddings and text encodings
|
||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||
|
||||
self.cond_on_text_encodings = cond_on_text_encodings
|
||||
self.cond_on_image_embeds = cond_on_image_embeds
|
||||
|
||||
# for classifier free guidance
|
||||
|
||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||
@@ -867,6 +901,7 @@ class Unet(nn.Module):
|
||||
|
||||
self.downs.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
|
||||
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
|
||||
Downsample(dim_out) if not is_last else nn.Identity()
|
||||
]))
|
||||
@@ -883,6 +918,7 @@ class Unet(nn.Module):
|
||||
|
||||
self.ups.append(nn.ModuleList([
|
||||
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
|
||||
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
|
||||
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
|
||||
Upsample(dim_in)
|
||||
]))
|
||||
@@ -893,6 +929,15 @@ class Unet(nn.Module):
|
||||
nn.Conv2d(dim, out_dim, 1)
|
||||
)
|
||||
|
||||
# if the current settings for the unet are not correct
|
||||
# for cascading DDPM, then reinit the unet with the right settings
|
||||
def force_lowres_cond(self, lowres_cond):
|
||||
if lowres_cond == self.lowres_cond:
|
||||
return self
|
||||
|
||||
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
|
||||
return self.__class__(**updated_kwargs)
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
*args,
|
||||
@@ -915,7 +960,9 @@ class Unet(nn.Module):
|
||||
image_embed,
|
||||
lowres_cond_img = None,
|
||||
text_encodings = None,
|
||||
cond_drop_prob = 0.
|
||||
cond_drop_prob = 0.,
|
||||
blur_sigma = None,
|
||||
blur_kernel_size = None
|
||||
):
|
||||
batch_size, device = x.shape[0], x.device
|
||||
|
||||
@@ -926,7 +973,9 @@ class Unet(nn.Module):
|
||||
if exists(lowres_cond_img):
|
||||
if self.training:
|
||||
# when training, blur the low resolution conditional image
|
||||
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
|
||||
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
|
||||
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
|
||||
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
|
||||
|
||||
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
|
||||
x = torch.cat((x, lowres_cond_img), dim = 1)
|
||||
@@ -943,17 +992,22 @@ class Unet(nn.Module):
|
||||
# mask out image embedding depending on condition dropout
|
||||
# for classifier free guidance
|
||||
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
image_tokens = None
|
||||
|
||||
image_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
)
|
||||
if self.cond_on_image_embeds:
|
||||
image_tokens = self.image_to_cond(image_embed)
|
||||
|
||||
image_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
image_tokens,
|
||||
self.null_image_embed
|
||||
)
|
||||
|
||||
# take care of text encodings (optional)
|
||||
|
||||
if exists(text_encodings):
|
||||
text_tokens = None
|
||||
|
||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||
text_tokens = self.text_to_cond(text_encodings)
|
||||
text_tokens = torch.where(
|
||||
cond_prob_mask,
|
||||
@@ -963,19 +1017,23 @@ class Unet(nn.Module):
|
||||
|
||||
# main conditioning tokens (c)
|
||||
|
||||
c = torch.cat((time_tokens, image_tokens), dim = -2)
|
||||
c = time_tokens
|
||||
|
||||
if exists(image_tokens):
|
||||
c = torch.cat((c, image_tokens), dim = -2)
|
||||
|
||||
# text and image conditioning tokens (mid_c)
|
||||
# to save on compute, only do cross attention based conditioning on the inner most layers of the Unet
|
||||
|
||||
mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
|
||||
mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)
|
||||
|
||||
# go through the layers of the unet, down and up
|
||||
|
||||
hiddens = []
|
||||
|
||||
for convnext, convnext2, downsample in self.downs:
|
||||
for convnext, sparse_attn, convnext2, downsample in self.downs:
|
||||
x = convnext(x, c)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c)
|
||||
hiddens.append(x)
|
||||
x = downsample(x)
|
||||
@@ -987,9 +1045,10 @@ class Unet(nn.Module):
|
||||
|
||||
x = self.mid_block2(x, mid_c)
|
||||
|
||||
for convnext, convnext2, upsample in self.ups:
|
||||
for convnext, sparse_attn, convnext2, upsample in self.ups:
|
||||
x = torch.cat((x, hiddens.pop()), dim=1)
|
||||
x = convnext(x, c)
|
||||
x = sparse_attn(x)
|
||||
x = convnext2(x, c)
|
||||
x = upsample(x)
|
||||
|
||||
@@ -1014,7 +1073,17 @@ class Decoder(nn.Module):
|
||||
self.clip_image_size = clip.image_size
|
||||
self.channels = clip.image_channels
|
||||
|
||||
self.unets = cast_tuple(unet)
|
||||
# automatically take care of ensuring that first unet is unconditional
|
||||
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
|
||||
|
||||
self.unets = nn.ModuleList([])
|
||||
for ind, one_unet in enumerate(cast_tuple(unet)):
|
||||
is_first = ind == 0
|
||||
one_unet = one_unet.force_lowres_cond(not is_first)
|
||||
self.unets.append(one_unet)
|
||||
|
||||
# unet image sizes
|
||||
|
||||
image_sizes = default(image_sizes, (clip.image_size,))
|
||||
image_sizes = tuple(sorted(set(image_sizes)))
|
||||
|
||||
@@ -1073,6 +1142,20 @@ class Decoder(nn.Module):
|
||||
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
|
||||
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
|
||||
|
||||
@contextmanager
|
||||
def one_unet_in_gpu(self, unet_number):
|
||||
assert 0 < unet_number <= len(self.unets)
|
||||
index = unet_number - 1
|
||||
self.cuda()
|
||||
self.unets.cpu()
|
||||
|
||||
unet = self.unets[index]
|
||||
unet.cuda()
|
||||
|
||||
yield
|
||||
|
||||
self.unets.cpu()
|
||||
|
||||
def get_text_encodings(self, text):
|
||||
text_encodings = self.clip.text_transformer(text)
|
||||
return text_encodings[:, 1:]
|
||||
@@ -1177,13 +1260,15 @@ class Decoder(nn.Module):
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
|
||||
img = None
|
||||
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
|
||||
shape = (batch_size, channels, image_size, image_size)
|
||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
||||
|
||||
for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
|
||||
with self.one_unet_in_gpu(ind + 1):
|
||||
shape = (batch_size, channels, image_size, image_size)
|
||||
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
|
||||
|
||||
return img
|
||||
|
||||
def forward(self, image, text = None, unet_number = None):
|
||||
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
|
||||
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
|
||||
unet_number = default(unet_number, 1)
|
||||
assert 1 <= unet_number <= len(self.unets)
|
||||
@@ -1199,8 +1284,10 @@ class Decoder(nn.Module):
|
||||
|
||||
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
|
||||
|
||||
image_embed = self.get_image_embed(image)
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) else None
|
||||
if not exists(image_embed):
|
||||
image_embed = self.get_image_embed(image)
|
||||
|
||||
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
|
||||
|
||||
lowres_cond_img = image if index > 0 else None
|
||||
ddpm_image = resize_image_to(image, target_image_size)
|
||||
|
||||
12
dalle2_pytorch/latent_diffusion.py
Normal file
12
dalle2_pytorch/latent_diffusion.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
class LatentDiffusion(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, x):
|
||||
return x
|
||||
Reference in New Issue
Block a user