Compare commits

...

13 Commits

4 changed files with 105 additions and 33 deletions

README.md

@@ -14,7 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
-There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
+There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Install
@@ -109,7 +109,7 @@ unet = Unet(
# decoder, which contains the unet and clip
decoder = Decoder(
-net = unet,
+unet = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
@@ -182,9 +182,9 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
-In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, from which DALL-E2 is based).
+In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2) for high resolution image synthesis.
-This can easily be used within the framework offered in this repository as so
+This can easily be used within this framework as so
```python
import torch
@@ -197,10 +197,10 @@ clip = CLIP(
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
-text_enc_depth = 1,
+text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
-visual_enc_depth = 1,
+visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
@@ -209,28 +209,28 @@ clip = CLIP(
# 2 unets for the decoder (a la cascading DDPM)
unet1 = Unet(
-dim = 16,
+dim = 32,
image_embed_dim = 512,
+cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8)
).cuda()
unet2 = Unet(
-dim = 16,
+dim = 32,
image_embed_dim = 512,
-lowres_cond = True, # subsequence unets must have this turned on (and first unet must have this turned off)
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
-# decoder, which contains the unet and clip
+# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
-image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second
+image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
-timesteps = 100,
+timesteps = 1000,
cond_drop_prob = 0.2
).cuda()
@@ -257,7 +257,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
-Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
+Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and unet(s))
```python
from dalle2_pytorch import DALLE2
@@ -349,8 +349,7 @@ unet2 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
-dim_mults = (1, 2, 4, 8, 16),
+dim_mults = (1, 2, 4, 8, 16)
-lowres_cond = True
).cuda()
decoder = Decoder(
@@ -410,12 +409,12 @@ Offer training wrappers
- [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much a difference)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
-- [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest
+- [x] add efficient attention in unet
-- [ ] make unet more configurable
+- [ ] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
-- [ ] train on a toy task, offer in colab
+- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
-- [ ] add attention to unet - apply some personal tricks with efficient attention
+- use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting)
-- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
+- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
+- [ ] train on a toy task, offer in colab
## Citations
@@ -464,4 +463,12 @@ Offer training wrappers
}
```
+```bibtex
+@inproceedings{Tu2022MaxViTMV,
+title = {MaxViT: Multi-Axis Vision Transformer},
+author = {Zhe-Wei Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+year = {2022}
+}
+```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

dalle2_pytorch/dalle2_pytorch.py

@@ -1,6 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
+from functools import partial
import torch
import torch.nn.functional as F
@@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
-from kornia.filters.gaussian import GaussianBlur2d
+from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -104,8 +105,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
-x = torch.linspace(0, steps, steps)
+x = torch.linspace(0, timesteps, steps)
-alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
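The rewritten lines above spell out the cosine schedule of Nichol & Dhariwal, where alphas_cumprod(t) = f(t) / f(0) with f(t) = cos^2(((t / T) + s) / (1 + s) * pi / 2) and t running from 0 to `timesteps`. The produced values are essentially unchanged (the old `x / steps` with `x = linspace(0, steps, steps)` also reduces to t / T), so this is mainly a clarity fix to match the published formula. A minimal standalone sketch of what the schedule produces, assuming only torch:

```python
import torch

def cosine_beta_schedule(timesteps, s = 0.008):
    # f(t) = cos(((t / T) + s) / (1 + s) * pi / 2) ** 2, then alphas_cumprod(t) = f(t) / f(0)
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(1000)
print(betas.shape)           # torch.Size([1000]) - one beta per diffusion step
print(betas[0], betas[-1])   # tiny near t = 0, approaching 1 (clipped to 0.999) at t = T
```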
@@ -797,6 +798,20 @@ class CrossAttention(nn.Module):
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
+class GridAttention(nn.Module):
+    def __init__(self, *args, window_size = 8, **kwargs):
+        super().__init__()
+        self.window_size = window_size
+        self.attn = Attention(*args, **kwargs)
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        wsz = self.window_size
+        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
+        out = self.attn(x)
+        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
+        return out
class Unet(nn.Module):
def __init__(
self,
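The `GridAttention` block added above is the grid (dilated) half of MaxViT-style efficient attention: instead of attending within a local window, each group gathers `window_size x window_size` positions strided evenly across the whole feature map, so the attention has global reach while each softmax only spans `window_size ** 2` tokens. A small sketch of the rearrange alone (independent of the repository's `Attention` class, which is assumed to take sequences of shape `(batch, tokens, dim)`), showing which positions land in the same group:

```python
import torch
from einops import rearrange

wsz = 2                                          # tiny window size for illustration
fmap = torch.arange(16.).reshape(1, 1, 4, 4)     # (b, c, H, W) feature map with values 0..15

# same rearrange as GridAttention.forward: groups are strided grids, not contiguous windows
tokens = rearrange(fmap, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
print(tokens.shape)      # torch.Size([4, 4, 1]) -> 4 groups of 4 tokens each
print(tokens[0, :, 0])   # tensor([ 0.,  2.,  8., 10.]) -> positions 2 apart in both H and W
```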
@@ -811,15 +826,24 @@ class Unet(nn.Module):
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
+blur_kernel_size = 3,
+sparse_attn = False,
+sparse_attn_window = 8, # window size for sparse attention
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
):
super().__init__()
+# save locals to take care of some hyperparameters for cascading DDPM
+self._locals = locals()
+del self._locals['self']
+del self._locals['__class__']
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
-self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
+self.lowres_blur_kernel_size = blur_kernel_size
+self.lowres_blur_sigma = blur_sigma
# determine dimensions
@@ -867,6 +891,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
+Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
@@ -883,6 +908,7 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
+Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in)
]))
@@ -893,6 +919,15 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1)
)
+# if the current settings for the unet are not correct
+# for cascading DDPM, then reinit the unet with the right settings
+def force_lowres_cond(self, lowres_cond):
+    if lowres_cond == self.lowres_cond:
+        return self
+    updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
+    return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
self,
*args,
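The `force_lowres_cond` method above works because `__init__` now records all of its constructor arguments via `self._locals = locals()`: the unet can rebuild itself with a single flag flipped instead of relying on the user to set `lowres_cond` correctly by hand. A minimal sketch of the same capture-and-reinit pattern in isolation (the `Toy` class is hypothetical, not part of the repository):

```python
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, dim = 32, lowres_cond = False):
        super().__init__()
        # record constructor arguments so the module can re-instantiate itself later
        self._locals = locals()
        del self._locals['self']
        del self._locals['__class__']   # present because the zero-arg super() call creates a __class__ cell
        self.lowres_cond = lowres_cond
        self.proj = nn.Linear(dim, dim)

    def force_lowres_cond(self, lowres_cond):
        if lowres_cond == self.lowres_cond:
            return self                                       # already configured, reuse as-is
        updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
        return self.__class__(**updated_kwargs)               # fresh instance with the flag flipped

toy = Toy(dim = 64)
print(toy.force_lowres_cond(True).lowres_cond)   # True -> a new Toy was constructed with dim = 64
```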
@@ -915,7 +950,9 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
-cond_drop_prob = 0.
+cond_drop_prob = 0.,
+blur_sigma = None,
+blur_kernel_size = None
):
batch_size, device = x.shape[0], x.device
@@ -926,7 +963,9 @@ class Unet(nn.Module):
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
-lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
+blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
+blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
+lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
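Moving from the `GaussianBlur2d` module to kornia's functional `gaussian_blur2d` lets the blur kernel size and sigma be overridden per forward call rather than being fixed at construction. A quick standalone sketch of the functional call with the default settings above (kernel size 3, sigma 0.1), assuming kornia is installed:

```python
import torch
from kornia.filters import gaussian_blur2d

lowres = torch.randn(1, 3, 64, 64)                      # mock low resolution conditioning image
blurred = gaussian_blur2d(lowres, (3, 3), (0.1, 0.1))   # kernel_size and sigma passed as (h, w) tuples
print(blurred.shape)                                    # torch.Size([1, 3, 64, 64])
```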
@@ -974,8 +1013,9 @@ class Unet(nn.Module):
hiddens = []
-for convnext, convnext2, downsample in self.downs:
+for convnext, sparse_attn, convnext2, downsample in self.downs:
x = convnext(x, c)
+x = sparse_attn(x)
x = convnext2(x, c)
hiddens.append(x)
x = downsample(x)
@@ -987,9 +1027,10 @@ class Unet(nn.Module):
x = self.mid_block2(x, mid_c)
-for convnext, convnext2, upsample in self.ups:
+for convnext, sparse_attn, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c)
+x = sparse_attn(x)
x = convnext2(x, c)
x = upsample(x)
@@ -1014,7 +1055,17 @@ class Decoder(nn.Module):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
-self.unets = cast_tuple(unet)
+# automatically take care of ensuring that first unet is unconditional
+# while the rest of the unets are conditioned on the low resolution image produced by previous unet
+self.unets = nn.ModuleList([])
+for ind, one_unet in enumerate(cast_tuple(unet)):
+    is_first = ind == 0
+    one_unet = one_unet.force_lowres_cond(not is_first)
+    self.unets.append(one_unet)
+# unet image sizes
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
@@ -1183,7 +1234,7 @@ class Decoder(nn.Module):
return img
-def forward(self, image, text = None, unet_number = None):
+def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
@@ -1199,8 +1250,10 @@ class Decoder(nn.Module):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
-image_embed = self.get_image_embed(image)
-text_encodings = self.get_text_encodings(text) if exists(text) else None
+if not exists(image_embed):
+    image_embed = self.get_image_embed(image)
+text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = image if index > 0 else None
ddpm_image = resize_image_to(image, target_image_size)
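With the new `forward` signature, each unet of the cascade is trained by selecting it with `unet_number`, and precomputed `image_embed` / `text_encodings` can be passed straight in to skip the CLIP forward passes inside the decoder. A rough training-step sketch, assuming `decoder` is built as in the README example above (the mock tensors and the `precomputed_image_embed` variable are illustrative, not from the repository):

```python
import torch

# mock data (shapes follow the README example above)
images = torch.randn(4, 3, 512, 512).cuda()
texts = torch.randint(0, 49408, (4, 256)).cuda()

# one training step per unet in the cascade; the decoder internally resizes `images`
# down to the resolution assigned to the chosen unet
for unet_number in (1, 2):
    loss = decoder(images, text = texts, unet_number = unet_number)
    loss.backward()

# if CLIP image embeddings are already cached for the dataset, they can now be passed
# in directly, skipping the CLIP image encoder inside the decoder
precomputed_image_embed = torch.randn(4, 512).cuda()   # hypothetical cached embeddings
loss = decoder(images, image_embed = precomputed_image_embed, unet_number = 1)
loss.backward()
```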

New file (LatentDiffusion)

@@ -0,0 +1,12 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+from einops import rearrange
+class LatentDiffusion(nn.Module):
+    def __init__(self):
+        super().__init__()
+    def forward(self, x):
+        return x

setup.py

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
-version = '0.0.20',
+version = '0.0.26',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',