Compare commits

..

10 Commits

5 changed files with 416 additions and 31 deletions

View File

@@ -760,7 +760,7 @@ decoder = Decoder(
unet = (unet1, unet2),
image_sizes = (128, 256),
clip = clip,
timesteps = 1,
timesteps = 1000,
condition_on_text_encodings = True
).cuda()
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
loss.backward()
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
## CLI (wip)
@@ -811,15 +817,16 @@ Once built, images will be saved to the same directory the command is invoked
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [ ] take care of mixed precision as well as gradient accumulation within decoder trainer
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
- [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [x] bring in tools to train vqgan-vae
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
- [ ] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
- [ ] train on a toy task, offer in colab
- [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
- [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
- [ ] bring in tools to train vqgan-vae
- [ ] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
## Citations

View File

@@ -922,6 +922,7 @@ class ConvNextBlock(nn.Module):
dim_out,
*,
cond_dim = None,
time_cond_dim = None,
mult = 2,
norm = True
):
@@ -940,6 +941,14 @@ class ConvNextBlock(nn.Module):
)
)
self.time_mlp = None
if exists(time_cond_dim):
self.time_mlp = nn.Sequential(
nn.GELU(),
nn.Linear(time_cond_dim, dim)
)
self.ds_conv = nn.Conv2d(dim, dim, 7, padding = 3, groups = dim)
inner_dim = int(dim_out * mult)
@@ -952,9 +961,13 @@ class ConvNextBlock(nn.Module):
self.res_conv = nn.Conv2d(dim, dim_out, 1) if need_projection else nn.Identity()
def forward(self, x, cond = None):
def forward(self, x, cond = None, time = None):
h = self.ds_conv(x)
if exists(time) and exists(self.time_mlp):
t = self.time_mlp(time)
h = rearrange(t, 'b c -> b c 1 1') + h
if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h
@@ -1059,6 +1072,8 @@ class Unet(nn.Module):
cond_on_text_encodings = False,
max_text_len = 256,
cond_on_image_embeds = False,
init_dim = None,
init_conv_kernel_size = 7
):
super().__init__()
# save locals to take care of some hyperparameters for cascading DDPM
@@ -1076,22 +1091,34 @@ class Unet(nn.Module):
self.channels = channels
init_channels = channels if not lowres_cond else channels * 2 # in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
init_dim = default(init_dim, dim // 2)
dims = [init_channels, *map(lambda m: dim * m, dim_mults)]
assert (init_conv_kernel_size % 2) == 1
self.init_conv = nn.Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
in_out = list(zip(dims[:-1], dims[1:]))
# time, image embeddings, and optional text encoding
cond_dim = default(cond_dim, dim)
time_cond_dim = dim * 4
self.time_mlp = nn.Sequential(
self.to_time_hiddens = nn.Sequential(
SinusoidalPosEmb(dim),
nn.Linear(dim, dim * 4),
nn.GELU(),
nn.Linear(dim * 4, cond_dim * num_time_tokens),
nn.Linear(dim, time_cond_dim),
nn.GELU()
)
self.to_time_tokens = nn.Sequential(
nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
Rearrange('b (r d) -> b r d', r = num_time_tokens)
)
self.to_time_cond = nn.Sequential(
nn.Linear(time_cond_dim, time_cond_dim)
)
self.image_to_cond = nn.Sequential(
nn.Linear(image_embed_dim, cond_dim * num_image_tokens),
Rearrange('b (n d) -> b n d', n = num_image_tokens)
@@ -1133,26 +1160,26 @@ class Unet(nn.Module):
layer_cond_dim = cond_dim if not is_first else None
self.downs.append(nn.ModuleList([
ConvNextBlock(dim_in, dim_out, norm = ind != 0),
ConvNextBlock(dim_in, dim_out, time_cond_dim = time_cond_dim, norm = ind != 0),
Residual(GridAttention(dim_out, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim),
ConvNextBlock(dim_out, dim_out, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Downsample(dim_out) if not is_last else nn.Identity()
]))
mid_dim = dims[-1]
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_block1 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
self.mid_attn = EinopsToAndFrom('b c h w', 'b (h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim)
self.mid_block2 = ConvNextBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim)
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
is_last = ind >= (num_resolutions - 2)
layer_cond_dim = cond_dim if not is_last else None
self.ups.append(nn.ModuleList([
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim),
ConvNextBlock(dim_out * 2, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Residual(GridAttention(dim_in, window_size = sparse_attn_window, **attn_kwargs)) if sparse_attn else nn.Identity(),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim),
ConvNextBlock(dim_in, dim_in, cond_dim = layer_cond_dim, time_cond_dim = time_cond_dim),
Upsample(dim_in)
]))
@@ -1214,9 +1241,16 @@ class Unet(nn.Module):
if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
# initial convolution
x = self.init_conv(x)
# time conditioning
time_tokens = self.time_mlp(time)
time_hiddens = self.to_time_hiddens(time)
time_tokens = self.to_time_tokens(time_hiddens)
t = self.to_time_cond(time_hiddens)
# conditional dropout
@@ -1283,24 +1317,24 @@ class Unet(nn.Module):
hiddens = []
for convnext, sparse_attn, convnext2, downsample in self.downs:
x = convnext(x, c)
x = convnext(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c)
x = convnext2(x, c, t)
hiddens.append(x)
x = downsample(x)
x = self.mid_block1(x, mid_c)
x = self.mid_block1(x, mid_c, t)
if exists(self.mid_attn):
x = self.mid_attn(x)
x = self.mid_block2(x, mid_c)
x = self.mid_block2(x, mid_c, t)
for convnext, sparse_attn, convnext2, upsample in self.ups:
x = torch.cat((x, hiddens.pop()), dim=1)
x = convnext(x, c)
x = convnext(x, c, t)
x = sparse_attn(x)
x = convnext2(x, c)
x = convnext2(x, c, t)
x = upsample(x)
return self.final_conv(x)
@@ -1540,7 +1574,13 @@ class Decoder(BaseGaussianDiffusion):
@torch.no_grad()
@eval_decorator
def sample(self, image_embed, text = None, cond_scale = 1.):
def sample(
self,
image_embed,
text = None,
cond_scale = 1.,
stop_at_unet_number = None
):
batch_size = image_embed.shape[0]
text_encodings = text_mask = None
@@ -1552,7 +1592,7 @@ class Decoder(BaseGaussianDiffusion):
img = None
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
@@ -1584,6 +1624,9 @@ class Decoder(BaseGaussianDiffusion):
img = vae.decode(img)
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
break
return img
def forward(

View File

@@ -3,12 +3,19 @@ from functools import partial
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder
from dalle2_pytorch.optimizer import get_optimizer
# helper functions
def exists(val):
return val is not None
def cast_tuple(val, length = 1):
return val if isinstance(val, tuple) else ((val,) * length)
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
@@ -89,6 +96,10 @@ class DecoderTrainer(nn.Module):
self,
decoder,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
@@ -106,24 +117,82 @@ class DecoderTrainer(nn.Module):
self.ema_unets = nn.ModuleList([])
for ind, unet in enumerate(self.decoder.unets):
optimizer = get_optimizer(unet.parameters(), **kwargs)
self.amp = amp
# be able to finely customize learning rate, weight decay
# per unet
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
optimizer = get_optimizer(
unet.parameters(),
lr = unet_lr,
wd = unet_wd,
**kwargs
)
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
if self.use_ema:
self.ema_unets.append(EMA(unet, **ema_kwargs))
scaler = GradScaler(enabled = amp)
setattr(self, f'scaler{ind}', scaler)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
@property
def unets(self):
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
def scale(self, loss, *, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
scaler = getattr(self, f'scaler{index}')
return scaler.scale(loss)
def update(self, unet_number):
assert 1 <= unet_number <= self.num_unets
index = unet_number - 1
unet = self.decoder.unets[index]
if exists(self.max_grad_norm):
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
optimizer = getattr(self, f'optim{index}')
optimizer.step()
scaler = getattr(self, f'scaler{index}')
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
if self.use_ema:
ema_unet = self.ema_unets[index]
ema_unet.update()
def forward(self, x, *, unet_number, **kwargs):
return self.decoder(x, unet_number = unet_number, **kwargs)
@torch.no_grad()
def sample(self, *args, **kwargs):
if self.use_ema:
trainable_unets = self.decoder.unets
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
output = self.decoder.sample(*args, **kwargs)
if self.use_ema:
self.decoder.unets = trainable_unets # restore original training unets
return output
def forward(
self,
x,
*,
unet_number,
divisor = 1,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.decoder(x, unet_number = unet_number, **kwargs)
return self.scale(loss / divisor, unet_number = unet_number)

View File

@@ -0,0 +1,266 @@
from math import sqrt
import copy
from random import choice
from pathlib import Path
from shutil import rmtree
import torch
from torch import nn
from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.utils import make_grid, save_image
from einops import rearrange
from dalle2_pytorch.train import EMA
from dalle2_pytorch.vqgan_vae import VQGanVAE
from dalle2_pytorch.optimizer import get_optimizer
# helpers
def exists(val):
return val is not None
def noop(*args, **kwargs):
pass
def cycle(dl):
while True:
for data in dl:
yield data
def cast_tuple(t):
return t if isinstance(t, (tuple, list)) else (t,)
def yes_or_no(question):
answer = input(f'{question} (y/n) ')
return answer.lower() in ('yes', 'y')
def accum_log(log, new_logs):
for key, new_value in new_logs.items():
old_value = log.get(key, 0.)
log[key] = old_value + new_value
return log
# classes
class ImageDataset(Dataset):
def __init__(
self,
folder,
image_size,
exts = ['jpg', 'jpeg', 'png']
):
super().__init__()
self.folder = folder
self.image_size = image_size
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
print(f'{len(self.paths)} training samples found at {folder}')
self.transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize(image_size),
T.RandomHorizontalFlip(),
T.CenterCrop(image_size),
T.ToTensor()
])
def __len__(self):
return len(self.paths)
def __getitem__(self, index):
path = self.paths[index]
img = Image.open(path)
return self.transform(img)
# main trainer class
class VQGanVAETrainer(nn.Module):
def __init__(
self,
vae,
*,
num_train_steps,
lr,
batch_size,
folder,
grad_accum_every,
wd = 0.,
save_results_every = 100,
save_model_every = 1000,
results_folder = './results',
valid_frac = 0.05,
random_split_seed = 42,
ema_beta = 0.995,
ema_update_after_step = 2000,
ema_update_every = 10,
apply_grad_penalty_every = 4,
):
super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
image_size = vae.image_size
self.vae = vae
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
self.register_buffer('steps', torch.Tensor([0]))
self.num_train_steps = num_train_steps
self.batch_size = batch_size
self.grad_accum_every = grad_accum_every
all_parameters = set(vae.parameters())
discr_parameters = set(vae.discr.parameters())
vae_parameters = all_parameters - discr_parameters
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
# create dataset
self.ds = ImageDataset(folder, image_size = image_size)
# split for validation
if valid_frac > 0:
train_size = int((1 - valid_frac) * len(self.ds))
valid_size = len(self.ds) - train_size
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
else:
self.valid_ds = self.ds
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
# dataloader
self.dl = cycle(DataLoader(
self.ds,
batch_size = batch_size,
shuffle = True
))
self.valid_dl = cycle(DataLoader(
self.valid_ds,
batch_size = batch_size,
shuffle = True
))
self.save_model_every = save_model_every
self.save_results_every = save_results_every
self.apply_grad_penalty_every = apply_grad_penalty_every
self.results_folder = Path(results_folder)
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
rmtree(str(self.results_folder))
self.results_folder.mkdir(parents = True, exist_ok = True)
def train_step(self):
device = next(self.vae.parameters()).device
steps = int(self.steps.item())
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
self.vae.train()
# logs
logs = {}
# update vae (generator)
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
loss = self.vae(
img,
return_loss = True,
apply_grad_penalty = apply_grad_penalty
)
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.optim.step()
self.optim.zero_grad()
# update discriminator
if exists(self.vae.discr):
discr_loss = 0
for _ in range(self.grad_accum_every):
img = next(self.dl)
img = img.to(device)
loss = self.vae(img, return_discr_loss = True)
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward()
self.discr_optim.step()
self.discr_optim.zero_grad()
# log
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
# update exponential moving averaged generator
self.ema_vae.update()
# sample results every so often
if not (steps % self.save_results_every):
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
model.eval()
imgs = next(self.dl)
imgs = imgs.to(device)
recons = model(imgs)
nrows = int(sqrt(self.batch_size))
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
logs['reconstructions'] = grid
save_image(grid, str(self.results_folder / f'{filename}.png'))
print(f'{steps}: saving to {str(self.results_folder)}')
# save model every so often
if not (steps % self.save_model_every):
state_dict = self.vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.pt')
torch.save(state_dict, model_path)
ema_state_dict = self.ema_vae.state_dict()
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
torch.save(ema_state_dict, model_path)
print(f'{steps}: saving model to {str(self.results_folder)}')
self.steps += 1
return logs
def train(self, log_fn = noop):
device = next(self.vae.parameters()).device
while self.steps < self.num_train_steps:
logs = self.train_step()
log_fn(logs)
print('training complete')

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream'
],
},
version = '0.0.77',
version = '0.0.85',
license='MIT',
description = 'DALL-E 2',
author = 'Phil Wang',