Compare commits

...

7 Commits

5 changed files with 99 additions and 23 deletions

View File

@@ -14,7 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
 Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication

-There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
+There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.

 ## Install
@@ -409,9 +409,11 @@ Offer training wrappers
 - [x] augment unet so that it can also be conditioned on text encodings (although in paper they hinted this didn't make much of a difference)
 - [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
 - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
-- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unet separately)
+- [x] add efficient attention in unet
+- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
+- [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unet separately)
 - [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as a setting)
-- [ ] become an expert with unets, clean up unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
+- [ ] become an expert with unets, clean up unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] train on a toy task, offer in colab

 ## Citations
@@ -461,4 +463,12 @@ Offer training wrappers
 }
 ```

+```bibtex
+@inproceedings{Tu2022MaxViTMV,
+    title  = {MaxViT: Multi-Axis Vision Transformer},
+    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+    year   = {2022}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

View File

@@ -6,4 +6,4 @@ def main():
 @click.command()
 @click.argument('text')
 def dream(text):
-    return image
+    return 'not ready yet'

View File

@@ -2,6 +2,7 @@ import math
 from tqdm import tqdm
 from inspect import isfunction
 from functools import partial
+from contextlib import contextmanager

 import torch
 import torch.nn.functional as F
@@ -105,8 +106,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
     as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
     """
     steps = timesteps + 1
-    x = torch.linspace(0, steps, steps)
-    alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+    x = torch.linspace(0, timesteps, steps)
+    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
     alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
     betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
     return torch.clip(betas, 0, 0.999)
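
For what it is worth, the old and new expressions here are numerically equivalent (with steps = timesteps + 1, both normalized grids reduce to i / timesteps for i = 0..timesteps); the rewrite makes x an explicit timestep grid so the code reads like the t/T formulation of the improved-DDPM cosine schedule. A standalone sketch of the resulting function, runnable on its own:

```python
# Standalone sketch of the cosine beta schedule as patched above
# (schedule from "Improved Denoising Diffusion Probabilistic Models").
import torch

def cosine_beta_schedule(timesteps, s = 0.008):
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)  # integer timestep grid 0..timesteps
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]  # normalize so alpha-bar starts at 1
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(1000)
print(betas.shape)          # torch.Size([1000])
print(betas[0], betas[-1])  # tiny at t = 0, clipped to 0.999 at the end
```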
@@ -798,6 +799,20 @@ class CrossAttention(nn.Module):
         out = rearrange(out, 'b h n d -> b n (h d)')
         return self.to_out(out)

+class GridAttention(nn.Module):
+    def __init__(self, *args, window_size = 8, **kwargs):
+        super().__init__()
+        self.window_size = window_size
+        self.attn = Attention(*args, **kwargs)
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        wsz = self.window_size
+        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
+        out = self.attn(x)
+        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
+        return out
+
 class Unet(nn.Module):
     def __init__(
         self,
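
A note on the rearrange pattern in GridAttention: because the window size (w1, w2) is the outer factor of each spatial axis, every window gathers pixels strided across the whole feature map rather than a contiguous block, resembling the grid-attention branch of MaxViT (cited in the README change above). A self-contained sketch of just the reshape round trip, with the attention call stubbed out:

```python
# Sketch of the grid-window reshape used by GridAttention (einops only).
import torch
from einops import rearrange

b, c, H, W, wsz = 2, 16, 32, 32, 8
x = torch.randn(b, c, H, W)

# fold each spatial axis into (window, grid) factors;
# window positions become the attention sequence axis
windows = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
print(windows.shape)  # torch.Size([32, 64, 16]), i.e. (b * 4 * 4, 8 * 8, c)

# ... attention would run over the (w1 w2) axis here ...

# invert the reshape to recover the original feature map layout
out = rearrange(windows, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)',
                w1 = wsz, w2 = wsz, h = H // wsz, w = W // wsz)
assert torch.equal(out, x)
```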
@@ -806,6 +821,7 @@ class Unet(nn.Module):
         image_embed_dim,
         cond_dim = None,
         num_image_tokens = 4,
+        num_time_tokens = 2,
         out_dim = None,
         dim_mults=(1, 2, 4, 8),
         channels = 3,
@@ -813,7 +829,11 @@ class Unet(nn.Module):
         lowres_cond_upsample_mode = 'bilinear',
         blur_sigma = 0.1,
         blur_kernel_size = 3,
+        sparse_attn = False,
+        sparse_attn_window = 8,  # window size for sparse attention
         attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
+        cond_on_text_encodings = False,
+        cond_on_image_embeds = False,
     ):
         super().__init__()

         # save locals to take care of some hyperparameters for cascading DDPM
@@ -846,8 +866,8 @@ class Unet(nn.Module):
             SinusoidalPosEmb(dim),
             nn.Linear(dim, dim * 4),
             nn.GELU(),
-            nn.Linear(dim * 4, cond_dim),
-            Rearrange('b d -> b 1 d')
+            nn.Linear(dim * 4, cond_dim * num_time_tokens),
+            Rearrange('b (r d) -> b r d', r = num_time_tokens)
         )

         self.image_to_cond = nn.Sequential(
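
For context on the time_mlp change: the time conditioning now emits num_time_tokens tokens instead of a single one, by widening the final projection and unflattening it. A minimal sketch of just that projection, with SinusoidalPosEmb omitted and placeholder dimensions:

```python
# Sketch of the multi-token time conditioning (placeholder dims, no pos emb).
import torch
from torch import nn
from einops.layers.torch import Rearrange

dim, cond_dim, num_time_tokens = 16, 32, 2

to_time_tokens = nn.Sequential(
    nn.Linear(dim, dim * 4),
    nn.GELU(),
    nn.Linear(dim * 4, cond_dim * num_time_tokens),    # one wide projection...
    Rearrange('b (r d) -> b r d', r = num_time_tokens) # ...split into r tokens
)

t = torch.randn(3, dim)          # batch of time embeddings
print(to_time_tokens(t).shape)   # torch.Size([3, 2, 32])
```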
@@ -857,6 +877,12 @@ class Unet(nn.Module):
         self.text_to_cond = nn.LazyLinear(cond_dim)

+        # finer control over whether to condition on image embeddings and text encodings
+        # so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
+        self.cond_on_text_encodings = cond_on_text_encodings
+        self.cond_on_image_embeds = cond_on_image_embeds
+
         # for classifier free guidance

         self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
@@ -875,6 +901,7 @@ class Unet(nn.Module):
             self.downs.append(nn.ModuleList([
                 ConvNextBlock(dim_in, dim_out, norm = ind != 0),
+                Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
                 Downsample(dim_out) if not is_last else nn.Identity()
             ]))
@@ -891,6 +918,7 @@ class Unet(nn.Module):
             self.ups.append(nn.ModuleList([
                 ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
+                Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                 ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
                 Upsample(dim_in)
             ]))
@@ -964,17 +992,22 @@ class Unet(nn.Module):
         # mask out image embedding depending on condition dropout
         # for classifier free guidance

-        image_tokens = self.image_to_cond(image_embed)
+        image_tokens = None

-        image_tokens = torch.where(
-            cond_prob_mask,
-            image_tokens,
-            self.null_image_embed
-        )
+        if self.cond_on_image_embeds:
+            image_tokens = self.image_to_cond(image_embed)
+
+            image_tokens = torch.where(
+                cond_prob_mask,
+                image_tokens,
+                self.null_image_embed
+            )

         # take care of text encodings (optional)

-        if exists(text_encodings):
+        text_tokens = None
+
+        if exists(text_encodings) and self.cond_on_text_encodings:
             text_tokens = self.text_to_cond(text_encodings)

             text_tokens = torch.where(
                 cond_prob_mask,
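
For context: cond_prob_mask here implements classifier-free guidance dropout, randomly swapping the real conditioning tokens for a learned null embedding, and this change simply gates each stream behind its cond_on_* flag. A minimal sketch of the masking alone, with illustrative placeholder shapes:

```python
# Sketch of classifier-free guidance dropout on conditioning tokens.
import torch

batch, num_tokens, cond_dim, cond_drop_prob = 4, 4, 32, 0.2

image_tokens = torch.randn(batch, num_tokens, cond_dim)
null_image_embed = torch.randn(1, num_tokens, cond_dim)  # a learned nn.Parameter in the Unet

# (batch, 1, 1) boolean mask: True = keep real conditioning, False = drop to null
cond_prob_mask = torch.rand(batch, 1, 1) > cond_drop_prob
image_tokens = torch.where(cond_prob_mask, image_tokens, null_image_embed)
```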
@@ -984,19 +1017,23 @@ class Unet(nn.Module):
         # main conditioning tokens (c)

-        c = torch.cat((time_tokens, image_tokens), dim = -2)
+        c = time_tokens
+
+        if exists(image_tokens):
+            c = torch.cat((c, image_tokens), dim = -2)

         # text and image conditioning tokens (mid_c)
         # to save on compute, only do cross attention based conditioning on the inner most layers of the Unet

-        mid_c = c if not exists(text_encodings) else torch.cat((c, text_tokens), dim = -2)
+        mid_c = c if not exists(text_tokens) else torch.cat((c, text_tokens), dim = -2)

         # go through the layers of the unet, down and up

         hiddens = []

-        for convnext, convnext2, downsample in self.downs:
+        for convnext, sparse_attn, convnext2, downsample in self.downs:
             x = convnext(x, c)
+            x = sparse_attn(x)
             x = convnext2(x, c)
             hiddens.append(x)
             x = downsample(x)
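
With this refactor both token streams are optional: time tokens are always present, image tokens are appended only when conditioning on image embeds, and text tokens only join the mid-level cross-attention context. A minimal sketch of the concatenation logic with placeholder tensors:

```python
# Sketch of how the conditioning token streams are assembled (placeholder shapes).
import torch

time_tokens  = torch.randn(1, 2, 32)  # always present: (batch, num_time_tokens, cond_dim)
image_tokens = torch.randn(1, 4, 32)  # None when cond_on_image_embeds = False
text_tokens  = None                   # None when cond_on_text_encodings = False

c = time_tokens
if image_tokens is not None:
    c = torch.cat((c, image_tokens), dim = -2)

# text joins only the inner-layer cross-attention context, to save compute
mid_c = c if text_tokens is None else torch.cat((c, text_tokens), dim = -2)
print(c.shape, mid_c.shape)  # torch.Size([1, 6, 32]) torch.Size([1, 6, 32])
```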
@@ -1008,9 +1045,10 @@ class Unet(nn.Module):
         x = self.mid_block2(x, mid_c)

-        for convnext, convnext2, upsample in self.ups:
+        for convnext, sparse_attn, convnext2, upsample in self.ups:
             x = torch.cat((x, hiddens.pop()), dim=1)
             x = convnext(x, c)
+            x = sparse_attn(x)
             x = convnext2(x, c)
             x = upsample(x)
@@ -1104,6 +1142,20 @@ class Decoder(nn.Module):
         self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

+    @contextmanager
+    def one_unet_in_gpu(self, unet_number):
+        assert 0 < unet_number <= len(self.unets)
+        index = unet_number - 1
+
+        self.cuda()
+        self.unets.cpu()
+
+        unet = self.unets[index]
+        unet.cuda()
+
+        yield
+
+        self.unets.cpu()
+
     def get_text_encodings(self, text):
         text_encodings = self.clip.text_transformer(text)
         return text_encodings[:, 1:]
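
The context manager above keeps only the unet currently being sampled from on the GPU, which is what lets the cascading sampler below walk through unets of different resolutions without holding all of them in device memory at once. A simplified standalone sketch of the same offloading pattern, with stub modules and guarded so it only runs where CUDA is available:

```python
# Simplified sketch of per-unet GPU offloading (stub modules, hypothetical names).
import torch
from torch import nn
from contextlib import contextmanager

class CascadeStub(nn.Module):
    def __init__(self):
        super().__init__()
        self.unets = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

    @contextmanager
    def one_unet_in_gpu(self, unet_number):
        assert 0 < unet_number <= len(self.unets)
        self.unets.cpu()                    # everything starts offloaded
        unet = self.unets[unet_number - 1]
        unet.cuda()                         # only the active unet goes to GPU
        yield unet
        self.unets.cpu()                    # offload again on exit

if torch.cuda.is_available():
    model = CascadeStub()
    with model.one_unet_in_gpu(1) as unet:
        out = unet(torch.randn(1, 8).cuda())
```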
@@ -1208,9 +1260,11 @@ class Decoder(nn.Module):
         text_encodings = self.get_text_encodings(text) if exists(text) else None

         img = None

-        for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
-            shape = (batch_size, channels, image_size, image_size)
-            img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
+        for ind, (unet, image_size) in tqdm(enumerate(zip(self.unets, self.image_sizes))):
+            with self.one_unet_in_gpu(ind + 1):
+                shape = (batch_size, channels, image_size, image_size)
+                img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)

         return img

View File

@@ -0,0 +1,12 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+
+from einops import rearrange
+
+class LatentDiffusion(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x

View File

@@ -10,7 +10,7 @@ setup(
             'dream = dalle2_pytorch.cli:dream'
         ],
     },
-    version = '0.0.24',
+    version = '0.0.28',
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',