Compare commits

...

6 Commits

Author     SHA1        Message                                             Date
Phil Wang  ddde8ca1bf  fix cosine beta schedule, thanks to @Zhengxinyang   2022-04-19 20:54:28 -07:00
Phil Wang  c26b77ad20  todo                                                2022-04-19 13:07:32 -07:00
Phil Wang  c5b4aab8e5  intent                                              2022-04-19 11:00:05 -07:00
Phil Wang  a35c309b5f  add sparse attention layers in between convnext blocks in unet (grid like attention, used in mobilevit, maxvit [bytedance ai], as well as a growing number of attention-based GANs)  2022-04-19 09:49:03 -07:00
Phil Wang  55bdcb98b9  scaffold for latent diffusion                       2022-04-19 09:26:58 -07:00
Phil Wang  82328f16cd  same for text encodings for decoder ddpm training   2022-04-18 14:41:02 -07:00
4 changed files with 51 additions and 9 deletions

View File

@@ -14,7 +14,7 @@ It may also explore an extension of using <a href="https://huggingface.co/spaces
Please join <a href="https://discord.gg/xBPBXfcFHd"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=discord&logoColor=white"></a> if you are interested in helping out with the replication
-There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>
+There was enough interest for a Jax version. It will be completed after the Pytorch version shows signs of life on my toy tasks. <a href="https://github.com/lucidrains/dalle2-jax">Placeholder repository</a>. I will also eventually extend this to <a href="https://github.com/lucidrains/dalle2-video">text to video</a>, once the repository is in a good place.
## Install
@@ -409,9 +409,11 @@ Offer training wrappers
- [x] augment unet so that it can also be conditioned on text encodings (although in the paper they hinted this didn't make much of a difference)
- [x] figure out all the current bag of tricks needed to make DDPMs great (starting with the blur trick mentioned in paper)
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [x] add efficient attention in unet
+- [ ] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
+- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution's unet separately)
- [ ] build out latent diffusion architecture in a separate file, as it is not faithful to dalle-2 (but offer it as a setting)
-- [ ] become an expert with unets, cleanup unet code, make it fully configurable, add efficient attention (conditional on resolution), port all learnings over to https://github.com/lucidrains/x-unet
+- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] train on a toy task, offer in colab
## Citations
@@ -461,4 +463,12 @@ Offer training wrappers
}
```
+```bibtex
+@inproceedings{Tu2022MaxViTMV,
+    title  = {MaxViT: Multi-Axis Vision Transformer},
+    author = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
+    year   = {2022}
+}
+```
*Creating noise from data is easy; creating data from noise is generative modeling.* - Yang Song's <a href="https://arxiv.org/abs/2011.13456">paper</a>

View File

@@ -105,8 +105,8 @@ def cosine_beta_schedule(timesteps, s = 0.008):
    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
    """
    steps = timesteps + 1
-    x = torch.linspace(0, steps, steps)
-    alphas_cumprod = torch.cos(((x / steps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+    x = torch.linspace(0, timesteps, steps)
+    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)
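
For context, this is Nichol & Dhariwal's cosine schedule, with alpha_bar(t) = f(t) / f(0) and f(t) = cos(((t/T + s) / (1 + s)) * pi/2)^2; the corrected `linspace` makes t run over 0, 1, ..., T so that t/T spans [0, 1] exactly. A self-contained copy of the fixed function with an illustrative sanity check (the checks are editorial, not part of the diff):

```python
import torch

def cosine_beta_schedule(timesteps, s = 0.008):
    # alpha_bar(t) = f(t) / f(0), f(t) = cos(((t / T + s) / (1 + s)) * pi / 2) ** 2
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)  # t = 0, 1, ..., T, so t / T spans [0, 1] exactly
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    # beta_t = 1 - alpha_bar(t) / alpha_bar(t - 1), clipped away from 1 for numerical stability
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, 0, 0.999)

betas = cosine_beta_schedule(1000)
assert betas.shape == (1000,)
assert (betas > 0).all() and (betas <= 0.999).all()
```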
@@ -798,6 +798,20 @@ class CrossAttention(nn.Module):
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

+class GridAttention(nn.Module):
+    def __init__(self, *args, window_size = 8, **kwargs):
+        super().__init__()
+        self.window_size = window_size
+        self.attn = Attention(*args, **kwargs)
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        wsz = self.window_size
+        x = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
+        out = self.attn(x)
+        out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)', w1 = wsz, w2 = wsz, h = h // wsz, w = w // wsz)
+        return out
+
class Unet(nn.Module):
    def __init__(
        self,
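
The rearrange pattern in `GridAttention` is the "grid" half of MaxViT-style attention: height and width are each factored into (window, grid) axes with `w1 = w2 = window_size`, so each attention sequence gathers one position from each of `window_size × window_size` locations strided evenly across the whole feature map, rather than a contiguous local block. Note it requires the feature-map side lengths to be divisible by `window_size`. A hedged shape walk-through, with the attention itself stubbed out so it runs standalone (dimensions are illustrative):

```python
import torch
from einops import rearrange

b, c, size, wsz = 2, 8, 32, 8           # illustrative dimensions
x = torch.randn(b, c, size, size)

# height/width factor as (w1 h) and (w2 w): the sequence axes w1, w2 pick
# positions strided size // wsz apart, one per window, spanning the whole map
tokens = rearrange(x, 'b c (w1 h) (w2 w) -> (b h w) (w1 w2) c', w1 = wsz, w2 = wsz)
print(tokens.shape)                     # torch.Size([32, 64, 8]) = (b * (size // wsz) ** 2, wsz * wsz, c)

out = tokens                            # self-attention over each wsz * wsz grid would run here
out = rearrange(out, '(b h w) (w1 w2) c -> b c (w1 h) (w2 w)',
                w1 = wsz, w2 = wsz, h = size // wsz, w = size // wsz)
assert out.shape == x.shape             # round-trips back to the original feature map
```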
@@ -813,6 +827,8 @@ class Unet(nn.Module):
        lowres_cond_upsample_mode = 'bilinear',
        blur_sigma = 0.1,
        blur_kernel_size = 3,
+        sparse_attn = False,
+        sparse_attn_window = 8,   # window size for sparse attention
        attend_at_middle = True,  # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
    ):
        super().__init__()
@@ -875,6 +891,7 @@ class Unet(nn.Module):
            self.downs.append(nn.ModuleList([
                ConvNextBlock(dim_in, dim_out, norm = ind != 0),
+                Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
                Downsample(dim_out) if not is_last else nn.Identity()
            ]))
@@ -891,6 +908,7 @@ class Unet(nn.Module):
            self.ups.append(nn.ModuleList([
                ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
+                Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(),
                ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
                Upsample(dim_in)
            ]))
@@ -995,8 +1013,9 @@ class Unet(nn.Module):
        hiddens = []

-        for convnext, convnext2, downsample in self.downs:
+        for convnext, sparse_attn, convnext2, downsample in self.downs:
            x = convnext(x, c)
+            x = sparse_attn(x)
            x = convnext2(x, c)
            hiddens.append(x)
            x = downsample(x)
@@ -1008,9 +1027,10 @@ class Unet(nn.Module):
        x = self.mid_block2(x, mid_c)

-        for convnext, convnext2, upsample in self.ups:
+        for convnext, sparse_attn, convnext2, upsample in self.ups:
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = convnext(x, c)
+            x = sparse_attn(x)
            x = convnext2(x, c)
            x = upsample(x)
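
Because the sparse attention slot is filled with `nn.Identity()` when `sparse_attn = False`, the tuple unpacking in both loops stays uniform and `x = sparse_attn(x)` degrades to a no-op. A minimal sketch of that optional-layer pattern (the modules here are illustrative stand-ins, not the repository's classes):

```python
import torch
from torch import nn

use_attn = False  # mirrors sparse_attn = False

stage = nn.ModuleList([
    nn.Conv2d(16, 16, 3, padding = 1),                    # stand-in for ConvNextBlock
    nn.Conv2d(16, 16, 1) if use_attn else nn.Identity(),  # stand-in for Residual(GridAttention(...))
    nn.Conv2d(16, 16, 3, padding = 1),                    # stand-in for the second ConvNextBlock
])

x = torch.randn(2, 16, 32, 32)
conv1, maybe_attn, conv2 = stage
x = conv2(maybe_attn(conv1(x)))  # maybe_attn is a no-op when disabled
```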
@@ -1214,7 +1234,7 @@ class Decoder(nn.Module):
        return img

-    def forward(self, image, text = None, image_embed = None, unet_number = None):
+    def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
        assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
        unet_number = default(unet_number, 1)
        assert 1 <= unet_number <= len(self.unets)
@@ -1233,7 +1253,7 @@ class Decoder(nn.Module):
        if not exists(image_embed):
            image_embed = self.get_image_embed(image)

-        text_encodings = self.get_text_encodings(text) if exists(text) else None
+        text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None

        lowres_cond_img = image if index > 0 else None

        ddpm_image = resize_image_to(image, target_image_size)
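
With both changes to `Decoder.forward`, training code can now precompute text encodings once and pass them in; `get_text_encodings` is only called as a fallback when `text` is given but `text_encodings` is not. A hedged usage sketch (the `decoder`, `images`, and `text` objects are assumed to be constructed as in the README):

```python
# encode once (e.g. in the dataloader), reuse across training steps
text_encodings = decoder.get_text_encodings(text)

loss = decoder(images, text_encodings = text_encodings, unet_number = 1)
loss.backward()
```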

View File

@@ -0,0 +1,12 @@
+import torch
+import torch.nn.functional as F
+from torch import nn, einsum
+
+from einops import rearrange
+
+class LatentDiffusion(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return x

View File

@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
-  version = '0.0.23',
+  version = '0.0.26',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',