Compare commits

..

5 Commits

3 changed files with 45 additions and 18 deletions

View File

@@ -411,8 +411,8 @@ Offer training wrappers
- [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions - [x] build the cascading ddpm by having Decoder class manage multiple unets at different resolutions
- [x] add efficient attention in unet - [x] add efficient attention in unet
- [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning) - [x] be able to finely customize what to condition on (text, image embed) for specific unet in the cascade (super resolution ddpms near the end may not need too much conditioning)
- [ ] build out latent diffusion architecture in separate file, as it is not faithful to dalle-2 (but offer it as as setting) - [x] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately)
- [ ] offload unets not being trained on to CPU for memory efficiency (for training each resolution unets separately) - [ ] build out latent diffusion architecture, make it completely optional (additional autoencoder + some regularizations [kl and vq regs]) (figure out if latent diffusion + cascading ddpm can be used in conjunction)
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] train on a toy task, offer in colab - [ ] train on a toy task, offer in colab

View File

@@ -2,6 +2,7 @@ import math
from tqdm import tqdm from tqdm import tqdm
from inspect import isfunction from inspect import isfunction
from functools import partial from functools import partial
from contextlib import contextmanager
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -463,11 +464,11 @@ class DiffusionPrior(nn.Module):
net, net,
*, *,
clip, clip,
timesteps=1000, timesteps = 1000,
cond_drop_prob=0.2, cond_drop_prob = 0.2,
loss_type="l1", loss_type = "l1",
predict_x0=True, predict_x0 = True,
beta_schedule="cosine", beta_schedule = "cosine",
): ):
super().__init__() super().__init__()
assert isinstance(clip, CLIP) assert isinstance(clip, CLIP)
@@ -824,6 +825,8 @@ class Unet(nn.Module):
out_dim = None, out_dim = None,
dim_mults=(1, 2, 4, 8), dim_mults=(1, 2, 4, 8),
channels = 3, channels = 3,
attn_dim_head = 32,
attn_heads = 8,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_cond_upsample_mode = 'bilinear', lowres_cond_upsample_mode = 'bilinear',
blur_sigma = 0.1, blur_sigma = 0.1,
@@ -887,6 +890,10 @@ class Unet(nn.Module):
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim)) self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim)) self.null_text_embed = nn.Parameter(torch.randn(1, 1, cond_dim))
# attention related params
attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head)
# layers # layers
self.downs = nn.ModuleList([]) self.downs = nn.ModuleList([])
@@ -900,7 +907,7 @@ class Unet(nn.Module):
self.downs.append(nn.ModuleList([ self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0), ConvNextBlock(dim_in, dim_out, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(), Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim), ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity() Downsample(dim_out) if not is_last else nn.Identity()
])) ]))
@@ -908,7 +915,7 @@ class Unet(nn.Module):
mid_dim = dims[-1] mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim))) if attend_at_middle else None self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim) self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
@@ -917,7 +924,7 @@ class Unet(nn.Module):
self.ups.append(nn.ModuleList([ self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim), ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window)) if sparse_attn else nn.Identity(), Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim), ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
Upsample(dim_in) Upsample(dim_in)
])) ]))
@@ -1141,6 +1148,25 @@ class Decoder(nn.Module):
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) self.register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
def get_unet(self, unet_number):
assert 0 < unet_number <= len(self.unets)
index = unet_number - 1
return self.unets[index]
@contextmanager
def one_unet_in_gpu(self, unet_number = None, unet = None):
assert exists(unet_number) ^ exists(unet)
if exists(unet_number):
unet = self.get_unet(unet_number)
self.cuda()
self.unets.cpu()
unet.cuda()
yield
unet.cpu()
def get_text_encodings(self, text): def get_text_encodings(self, text):
text_encodings = self.clip.text_transformer(text) text_encodings = self.clip.text_transformer(text)
return text_encodings[:, 1:] return text_encodings[:, 1:]
@@ -1245,20 +1271,21 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) else None text_encodings = self.get_text_encodings(text) if exists(text) else None
img = None img = None
for unet, image_size in tqdm(zip(self.unets, self.image_sizes)): for unet, image_size in tqdm(zip(self.unets, self.image_sizes)):
shape = (batch_size, channels, image_size, image_size) with self.one_unet_in_gpu(unet = unet):
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img) shape = (batch_size, channels, image_size, image_size)
img = self.p_sample_loop(unet, shape, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = img)
return img return img
def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None): def forward(self, image, text = None, image_embed = None, text_encodings = None, unet_number = None):
assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
unet_number = default(unet_number, 1) unet_number = default(unet_number, 1)
assert 1 <= unet_number <= len(self.unets)
index = unet_number - 1 unet = self.get_unet(unet_number)
unet = self.unets[index]
target_image_size = self.image_sizes[index] target_image_size = self.image_sizes[unet_number - 1]
b, c, h, w, device, = *image.shape, image.device b, c, h, w, device, = *image.shape, image.device
@@ -1272,7 +1299,7 @@ class Decoder(nn.Module):
text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None text_encodings = self.get_text_encodings(text) if exists(text) and not exists(text_encodings) else None
lowres_cond_img = image if index > 0 else None lowres_cond_img = image if unet_number > 1 else None
ddpm_image = resize_image_to(image, target_image_size) ddpm_image = resize_image_to(image, target_image_size)
return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img) return self.p_losses(unet, ddpm_image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img)

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.0.27', version = '0.0.31',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',