From 5bfbccda22814a6a7dc5fc9e8f774bbbed9693a7 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 1 May 2022 08:09:15 -0700
Subject: [PATCH] port over vqgan vae trainer

---
 README.md                         |   2 +-
 dalle2_pytorch/train_vqgan_vae.py | 266 ++++++++++++++++++++++++++++++
 2 files changed, 267 insertions(+), 1 deletion(-)
 create mode 100644 dalle2_pytorch/train_vqgan_vae.py

diff --git a/README.md b/README.md
index 0bb35e7..18c7986 100644
--- a/README.md
+++ b/README.md
@@ -819,13 +819,13 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
 - [x] take care of mixed precision as well as gradient accumulation within decoder trainer
 - [x] just take care of the training for the decoder in a wrapper class, as each unet in the cascade will need its own optimizer
+- [x] bring in tools to train vqgan-vae
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
 - [ ] train on a toy task, offer in colab
 - [ ] think about how best to design a declarative training config that handles preencoding for prior and training of multiple networks in decoder
 - [ ] extend diffusion head to use diffusion-gan (potentially using lightweight-gan) to speed up inference
-- [ ] bring in tools to train vqgan-vae
 - [ ] add convnext backbone for vqgan-vae (in addition to vit [vit-vqgan] + resnet)
 
 ## Citations
diff --git a/dalle2_pytorch/train_vqgan_vae.py b/dalle2_pytorch/train_vqgan_vae.py
new file mode 100644
index 0000000..cab99c3
--- /dev/null
+++ b/dalle2_pytorch/train_vqgan_vae.py
@@ -0,0 +1,266 @@
+from math import sqrt
+from pathlib import Path
+from shutil import rmtree
+
+import torch
+from torch import nn
+
+from PIL import Image
+import torchvision.transforms as T
+from torch.utils.data import Dataset, DataLoader, random_split
+from torchvision.utils import make_grid, save_image
+
+from einops import rearrange
+
+from dalle2_pytorch.train import EMA
+from dalle2_pytorch.vqgan_vae import VQGanVAE
+from dalle2_pytorch.optimizer import get_optimizer
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def noop(*args, **kwargs):
+    pass
+
+def cycle(dl):
+    # turn a dataloader into an infinite iterator
+    while True:
+        for data in dl:
+            yield data
+
+def cast_tuple(t):
+    return t if isinstance(t, (tuple, list)) else (t,)
+
+def yes_or_no(question):
+    answer = input(f'{question} (y/n) ')
+    return answer.lower() in ('yes', 'y')
+
+def accum_log(log, new_logs):
+    # sum newly logged values onto any existing ones, keyed by name
+    for key, new_value in new_logs.items():
+        old_value = log.get(key, 0.)
+        log[key] = old_value + new_value
+    return log
+
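+# example (hypothetical numbers): logging each microbatch loss divided by
+# grad_accum_every leaves the averaged loss for the whole step in the log
+#
+#   logs = {}
+#   for microbatch_loss in (0.25, 0.35):            # grad_accum_every = 2
+#       accum_log(logs, {'loss': microbatch_loss / 2})
+#   logs['loss']                                    # -> 0.3
+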
+# classes
+
+class ImageDataset(Dataset):
+    def __init__(
+        self,
+        folder,
+        image_size,
+        exts = ['jpg', 'jpeg', 'png']
+    ):
+        super().__init__()
+        self.folder = folder
+        self.image_size = image_size
+        self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
+
+        print(f'{len(self.paths)} training samples found at {folder}')
+
+        self.transform = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+            T.Resize(image_size),
+            T.RandomHorizontalFlip(),
+            T.CenterCrop(image_size),
+            T.ToTensor()
+        ])
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = self.paths[index]
+        img = Image.open(path)
+        return self.transform(img)
+
+# main trainer class
+
+class VQGanVAETrainer(nn.Module):
+    def __init__(
+        self,
+        vae,
+        *,
+        num_train_steps,
+        lr,
+        batch_size,
+        folder,
+        grad_accum_every,
+        wd = 0.,
+        save_results_every = 100,
+        save_model_every = 1000,
+        results_folder = './results',
+        valid_frac = 0.05,
+        random_split_seed = 42,
+        ema_beta = 0.995,
+        ema_update_after_step = 2000,
+        ema_update_every = 10,
+        apply_grad_penalty_every = 4,
+    ):
+        super().__init__()
+        assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
+        image_size = vae.image_size
+
+        self.vae = vae
+        self.ema_vae = EMA(vae, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)
+
+        self.register_buffer('steps', torch.Tensor([0]))
+
+        self.num_train_steps = num_train_steps
+        self.batch_size = batch_size
+        self.grad_accum_every = grad_accum_every
+
+        # give the generator (vae) and the discriminator separate optimizers,
+        # since they are updated in alternation
+
+        all_parameters = set(vae.parameters())
+        discr_parameters = set(vae.discr.parameters())
+        vae_parameters = all_parameters - discr_parameters
+
+        self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
+        self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
+
+        # create dataset
+
+        self.ds = ImageDataset(folder, image_size = image_size)
+
+        # split for validation
+
+        if valid_frac > 0:
+            train_size = int((1 - valid_frac) * len(self.ds))
+            valid_size = len(self.ds) - train_size
+            self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
+            print(f'training with dataset of {len(self.ds)} samples and validating with randomly split {len(self.valid_ds)} samples')
+        else:
+            self.valid_ds = self.ds
+            print(f'training with shared training and valid dataset of {len(self.ds)} samples')
+
+        # dataloader
+
+        self.dl = cycle(DataLoader(
+            self.ds,
+            batch_size = batch_size,
+            shuffle = True
+        ))
+
+        self.valid_dl = cycle(DataLoader(
+            self.valid_ds,
+            batch_size = batch_size,
+            shuffle = True
+        ))
+
+        self.save_model_every = save_model_every
+        self.save_results_every = save_results_every
+
+        self.apply_grad_penalty_every = apply_grad_penalty_every
+
+        self.results_folder = Path(results_folder)
+
+        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
+            rmtree(str(self.results_folder))
+
+        self.results_folder.mkdir(parents = True, exist_ok = True)
+
+    def train_step(self):
+        device = next(self.vae.parameters()).device
+        steps = int(self.steps.item())
+        apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
+
+        self.vae.train()
+
+        # logs
+
+        logs = {}
+
+        # update vae (generator)
+
+        for _ in range(self.grad_accum_every):
+            img = next(self.dl)
+            img = img.to(device)
+
+            loss = self.vae(
+                img,
+                return_loss = True,
+                apply_grad_penalty = apply_grad_penalty
+            )
+
+            accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
+
+            (loss / self.grad_accum_every).backward()
+
+        self.optim.step()
+        self.optim.zero_grad()
+
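+        # note on the scaling above: dividing each microbatch loss by
+        # grad_accum_every makes the accumulated gradients equivalent to one
+        # step over batch_size * grad_accum_every images, e.g. a hypothetical
+        # batch_size = 4 with grad_accum_every = 8 steps the optimizer once
+        # per 32 images
+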
+        # update discriminator
+
+        if exists(self.vae.discr):
+            for _ in range(self.grad_accum_every):
+                img = next(self.dl)
+                img = img.to(device)
+
+                loss = self.vae(img, return_discr_loss = True)
+                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
+
+                (loss / self.grad_accum_every).backward()
+
+            self.discr_optim.step()
+            self.discr_optim.zero_grad()
+
+        # log
+
+        print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs.get('discr_loss', 0.)}")
+
+        # update exponential moving averaged generator
+
+        self.ema_vae.update()
+
+        # sample results every so often
+
+        if not (steps % self.save_results_every):
+            for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
+                model.eval()
+
+                imgs = next(self.valid_dl)
+                imgs = imgs.to(device)
+
+                with torch.no_grad():
+                    recons = model(imgs)
+
+                # interleave originals and reconstructions, one pair per grid row
+
+                imgs_and_recons = torch.stack((imgs, recons), dim = 0)
+                imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
+
+                imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
+                grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
+
+                logs['reconstructions'] = grid
+
+                save_image(grid, str(self.results_folder / f'{filename}.png'))
+
+            print(f'{steps}: saving to {str(self.results_folder)}')
+
+        # save model every so often
+
+        if not (steps % self.save_model_every):
+            state_dict = self.vae.state_dict()
+            model_path = str(self.results_folder / f'vae.{steps}.pt')
+            torch.save(state_dict, model_path)
+
+            ema_state_dict = self.ema_vae.state_dict()
+            model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
+            torch.save(ema_state_dict, model_path)
+
+            print(f'{steps}: saving model to {str(self.results_folder)}')
+
+        self.steps += 1
+        return logs
+
+    def train(self, log_fn = noop):
+        while self.steps < self.num_train_steps:
+            logs = self.train_step()
+            log_fn(logs)
+
+        print('training complete')
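+
+# minimal usage sketch (hypothetical hyperparameters; assumes a VQGanVAE
+# constructed elsewhere and a local folder of training images):
+#
+#   vae = VQGanVAE(...)
+#
+#   trainer = VQGanVAETrainer(
+#       vae,
+#       folder = './path/to/images',
+#       num_train_steps = 50000,
+#       lr = 3e-4,
+#       batch_size = 4,
+#       grad_accum_every = 8
+#   )
+#
+#   trainer.train()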