Compare commits

...

7 Commits

3 changed files with 57 additions and 27 deletions

View File

@@ -109,7 +109,7 @@ unet = Unet(
# decoder, which contains the unet and clip
decoder = Decoder(
net = unet,
unet = unet,
clip = clip,
timesteps = 100,
cond_drop_prob = 0.2
@@ -182,9 +182,9 @@ loss.backward()
# now the diffusion prior can generate image embeddings from the text embeddings
```
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a>, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, from which DALL-E2 is based).
In the paper, they actually used a <a href="https://cascaded-diffusion.github.io/">recently discovered technique</a> for high resolution image synthesis, from <a href="http://www.jonathanho.me/">Jonathan Ho</a> himself (original author of DDPMs, the core technique used in DALL-E v2).
This can easily be used within the framework offered in this repository as so
This can easily be used within this framework like so
```python
import torch
@@ -197,10 +197,10 @@ clip = CLIP(
dim_image = 512,
dim_latent = 512,
num_text_tokens = 49408,
text_enc_depth = 1,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
visual_enc_depth = 1,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
@@ -209,28 +209,28 @@ clip = CLIP(
# 2 unets for the decoder (a la cascading DDPM)
unet1 = Unet(
dim = 16,
dim = 32,
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8)
).cuda()
unet2 = Unet(
dim = 16,
dim = 32,
image_embed_dim = 512,
lowres_cond = True, # subsequent unets must have this turned on (and first unet must have this turned off)
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16)
).cuda()
# decoder, which contains the unet and clip
# decoder, which contains the unet(s) and clip
decoder = Decoder(
clip = clip,
unet = (unet1, unet2), # insert both unets in order of low resolution to highest resolution (you can have as many stages as you want here)
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second
timesteps = 100,
image_sizes = (256, 512), # resolutions, 256 for first unet, 512 for second. these must be unique and in ascending order (matches with the unets passed in)
timesteps = 1000,
cond_drop_prob = 0.2
).cuda()
@@ -257,7 +257,7 @@ mock_image_embed = torch.randn(1, 512).cuda()
images = decoder.sample(mock_image_embed) # (1, 3, 512, 512)
```
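A minimal training sketch (an assumption on my part, reusing the `decoder` built above and the 1-indexed `unet_number` argument that `Decoder.forward` accepts, as shown in the source hunks further down): each unet of the cascade would be trained separately, something like

```python
# sketch only: train each unet of the cascade on its own,
# selecting it with the 1-indexed unet_number argument
mock_images = torch.randn(4, 3, 512, 512).cuda()

loss = decoder(mock_images, unet_number = 1)  # base 256px unet
loss.backward()

loss = decoder(mock_images, unet_number = 2)  # 512px super-resolution unet
loss.backward()
```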
Finally, to generate the DALL-E2 images from text. Insert the trained `DiffusionPrior` as well as the `Decoder` (which both contains `CLIP`, a unet, and a causal transformer)
Finally, to generate the DALL-E2 images from text, insert the trained `DiffusionPrior` as well as the `Decoder` (which wraps `CLIP`, the causal transformer, and the unet(s))
```python
from dalle2_pytorch import DALLE2
@@ -349,8 +349,7 @@ unet2 = Unet(
image_embed_dim = 512,
cond_dim = 128,
channels = 3,
dim_mults = (1, 2, 4, 8, 16),
lowres_cond = True
dim_mults = (1, 2, 4, 8, 16)
).cuda()
decoder = Decoder(
@@ -410,12 +409,10 @@ Offer training wrappers
- [x] augment unet so that it can also be conditioned on text encodings (although in the paper they hinted this didn't make much of a difference)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [ ] use an image resolution cutoff and do cross attention conditioning only if resources allow, and MLP + sum conditioning on rest
- [ ] make unet more configurable
- [ ] train on a toy task, offer in colab
- [ ] add attention to unet - apply some personal tricks with efficient attention - use the sparse attention mechanism from https://github.com/lucidrains/vit-pytorch#maxvit
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution's unet separately)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
- [ ] consider U2-net for decoder https://arxiv.org/abs/2005.09007 (also in separate file as experimental) build out https://github.com/lucidrains/x-unet
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
- [ ] train on a toy task, offer in colab
## Citations

View File

@@ -1,6 +1,7 @@
import math
from tqdm import tqdm
from inspect import isfunction
from functools import partial
import torch
import torch.nn.functional as F
@@ -11,7 +12,7 @@ from einops.layers.torch import Rearrange
from einops_exts import rearrange_many, repeat_many, check_shape
from einops_exts.torch import EinopsToAndFrom
from kornia.filters.gaussian import GaussianBlur2d
from kornia.filters import gaussian_blur2d
from dalle2_pytorch.tokenizer import tokenizer
@@ -811,15 +812,22 @@ class Unet(nn.Module):
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1,
blur_kernel_size = 3,
attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
self._locals = locals()
del self._locals['self']
del self._locals['__class__']
# for eventual cascading diffusion
self.lowres_cond = lowres_cond
self.lowres_cond_upsample_mode = lowres_cond_upsample_mode
self.lowres_cond_blur = GaussianBlur2d((3, 3), (blur_sigma, blur_sigma))
self.lowres_blur_kernel_size = blur_kernel_size
self.lowres_blur_sigma = blur_sigma
# determine dimensions
@@ -893,6 +901,15 @@ class Unet(nn.Module):
nn.Conv2d(dim, out_dim, 1)
)
# if the current settings for the unet are not correct
# for cascading DDPM, then reinit the unet with the right settings
def force_lowres_cond(self, lowres_cond):
if lowres_cond == self.lowres_cond:
return self
updated_kwargs = {**self._locals, 'lowres_cond': lowres_cond}
return self.__class__(**updated_kwargs)
def forward_with_cond_scale(
self,
*args,
@@ -915,7 +932,9 @@ class Unet(nn.Module):
image_embed,
lowres_cond_img = None,
text_encodings = None,
cond_drop_prob = 0.
cond_drop_prob = 0.,
blur_sigma = None,
blur_kernel_size = None
):
batch_size, device = x.shape[0], x.device
@@ -926,7 +945,9 @@ class Unet(nn.Module):
if exists(lowres_cond_img):
if self.training:
# when training, blur the low resolution conditional image
lowres_cond_img = self.lowres_cond_blur(lowres_cond_img)
blur_sigma = default(blur_sigma, self.lowres_blur_sigma)
blur_kernel_size = default(blur_kernel_size, self.lowres_blur_kernel_size)
lowres_cond_img = gaussian_blur2d(lowres_cond_img, cast_tuple(blur_kernel_size, 2), cast_tuple(blur_sigma, 2))
lowres_cond_img = resize_image_to(lowres_cond_img, x.shape[-2:], mode = self.lowres_cond_upsample_mode)
x = torch.cat((x, lowres_cond_img), dim = 1)
@@ -1014,7 +1035,17 @@ class Decoder(nn.Module):
self.clip_image_size = clip.image_size
self.channels = clip.image_channels
self.unets = cast_tuple(unet)
# automatically take care of ensuring that first unet is unconditional
# while the rest of the unets are conditioned on the low resolution image produced by previous unet
self.unets = nn.ModuleList([])
for ind, one_unet in enumerate(cast_tuple(unet)):
is_first = ind == 0
one_unet = one_unet.force_lowres_cond(not is_first)
self.unets.append(one_unet)
# unet image sizes
image_sizes = default(image_sizes, (clip.image_size,))
image_sizes = tuple(sorted(set(image_sizes)))
@@ -1183,7 +1214,7 @@ class Decoder(nn.Module):
return img
def forward(self, image, text = None, unet_number = None):
def forward(self, image, text = None, image_embed = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
@@ -1199,7 +1230,9 @@ class Decoder(nn.Module):
times = torch.randint(0, self.num_timesteps, (b,), device = device, dtype = torch.long)
image_embed = self.get_image_embed(image)
if not exists(image_embed):
image_embed = self.get_image_embed(image)
text_encodings = self.get_text_encodings(text) if exists(text) else None
lowres_cond_img = image if index > 0 else None
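Hedged usage note on the hunk above: with the new optional `image_embed` argument on `Decoder.forward`, a precomputed CLIP image embedding can presumably be passed in so the decoder skips recomputing it from the raw image, roughly along these lines (the mock tensors below are illustrative, not from the diff):

```python
# sketch only: supply a precomputed (here mocked) CLIP image embedding,
# so the decoder does not call get_image_embed(image) internally
mock_images = torch.randn(4, 3, 256, 256).cuda()
mock_image_embed = torch.randn(4, 512).cuda()  # would come from CLIP in practice

loss = decoder(mock_images, image_embed = mock_image_embed, unet_number = 1)
loss.backward()
```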

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.20',
version = '0.0.23',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',