mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
port over vqgan vae trainer
This commit is contained in:
266
dalle2_pytorch/train_vqgan_vae.py
Normal file
266
dalle2_pytorch/train_vqgan_vae.py
Normal file
@@ -0,0 +1,266 @@
|
||||
from math import sqrt
|
||||
import copy
|
||||
from random import choice
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from PIL import Image
|
||||
from torchvision.datasets import ImageFolder
|
||||
import torchvision.transforms as T
|
||||
from torch.utils.data import Dataset, DataLoader, random_split
|
||||
from torchvision.utils import make_grid, save_image
|
||||
|
||||
from einops import rearrange
|
||||
|
||||
from dalle2_pytorch.train import EMA
|
||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||
from dalle2_pytorch.optimizer import get_optimizer
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data in dl:
|
||||
yield data
|
||||
|
||||
def cast_tuple(t):
|
||||
return t if isinstance(t, (tuple, list)) else (t,)
|
||||
|
||||
def yes_or_no(question):
|
||||
answer = input(f'{question} (y/n) ')
|
||||
return answer.lower() in ('yes', 'y')
|
||||
|
||||
def accum_log(log, new_logs):
|
||||
for key, new_value in new_logs.items():
|
||||
old_value = log.get(key, 0.)
|
||||
log[key] = old_value + new_value
|
||||
return log
|
||||
|
||||
# classes
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
folder,
|
||||
image_size,
|
||||
exts = ['jpg', 'jpeg', 'png']
|
||||
):
|
||||
super().__init__()
|
||||
self.folder = folder
|
||||
self.image_size = image_size
|
||||
self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
|
||||
|
||||
print(f'{len(self.paths)} training samples found at {folder}')
|
||||
|
||||
self.transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize(image_size),
|
||||
T.RandomHorizontalFlip(),
|
||||
T.CenterCrop(image_size),
|
||||
T.ToTensor()
|
||||
])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
img = Image.open(path)
|
||||
return self.transform(img)
|
||||
|
||||
# main trainer class
|
||||
|
||||
class VQGanVAETrainer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
*,
|
||||
num_train_steps,
|
||||
lr,
|
||||
batch_size,
|
||||
folder,
|
||||
grad_accum_every,
|
||||
wd = 0.,
|
||||
save_results_every = 100,
|
||||
save_model_every = 1000,
|
||||
results_folder = './results',
|
||||
valid_frac = 0.05,
|
||||
random_split_seed = 42,
|
||||
ema_beta = 0.995,
|
||||
ema_update_after_step = 2000,
|
||||
ema_update_every = 10,
|
||||
apply_grad_penalty_every = 4,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
|
||||
image_size = vae.image_size
|
||||
|
||||
self.vae = vae
|
||||
self.ema_vae = EMA(vae, update_after_step = ema_update_after_step, update_every = ema_update_every)
|
||||
|
||||
self.register_buffer('steps', torch.Tensor([0]))
|
||||
|
||||
self.num_train_steps = num_train_steps
|
||||
self.batch_size = batch_size
|
||||
self.grad_accum_every = grad_accum_every
|
||||
|
||||
all_parameters = set(vae.parameters())
|
||||
discr_parameters = set(vae.discr.parameters())
|
||||
vae_parameters = all_parameters - discr_parameters
|
||||
|
||||
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
|
||||
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
|
||||
|
||||
# create dataset
|
||||
|
||||
self.ds = ImageDataset(folder, image_size = image_size)
|
||||
|
||||
# split for validation
|
||||
|
||||
if valid_frac > 0:
|
||||
train_size = int((1 - valid_frac) * len(self.ds))
|
||||
valid_size = len(self.ds) - train_size
|
||||
self.ds, self.valid_ds = random_split(self.ds, [train_size, valid_size], generator = torch.Generator().manual_seed(random_split_seed))
|
||||
print(f'training with dataset of {len(self.ds)} samples and validating with randomly splitted {len(self.valid_ds)} samples')
|
||||
else:
|
||||
self.valid_ds = self.ds
|
||||
print(f'training with shared training and valid dataset of {len(self.ds)} samples')
|
||||
|
||||
# dataloader
|
||||
|
||||
self.dl = cycle(DataLoader(
|
||||
self.ds,
|
||||
batch_size = batch_size,
|
||||
shuffle = True
|
||||
))
|
||||
|
||||
self.valid_dl = cycle(DataLoader(
|
||||
self.valid_ds,
|
||||
batch_size = batch_size,
|
||||
shuffle = True
|
||||
))
|
||||
|
||||
self.save_model_every = save_model_every
|
||||
self.save_results_every = save_results_every
|
||||
|
||||
self.apply_grad_penalty_every = apply_grad_penalty_every
|
||||
|
||||
self.results_folder = Path(results_folder)
|
||||
|
||||
if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
|
||||
rmtree(str(self.results_folder))
|
||||
|
||||
self.results_folder.mkdir(parents = True, exist_ok = True)
|
||||
|
||||
def train_step(self):
|
||||
device = next(self.vae.parameters()).device
|
||||
steps = int(self.steps.item())
|
||||
apply_grad_penalty = not (steps % self.apply_grad_penalty_every)
|
||||
|
||||
self.vae.train()
|
||||
|
||||
# logs
|
||||
|
||||
logs = {}
|
||||
|
||||
# update vae (generator)
|
||||
|
||||
for _ in range(self.grad_accum_every):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(
|
||||
img,
|
||||
return_loss = True,
|
||||
apply_grad_penalty = apply_grad_penalty
|
||||
)
|
||||
|
||||
accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.optim.step()
|
||||
self.optim.zero_grad()
|
||||
|
||||
|
||||
# update discriminator
|
||||
|
||||
if exists(self.vae.discr):
|
||||
discr_loss = 0
|
||||
for _ in range(self.grad_accum_every):
|
||||
img = next(self.dl)
|
||||
img = img.to(device)
|
||||
|
||||
loss = self.vae(img, return_discr_loss = True)
|
||||
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
|
||||
|
||||
(loss / self.grad_accum_every).backward()
|
||||
|
||||
self.discr_optim.step()
|
||||
self.discr_optim.zero_grad()
|
||||
|
||||
# log
|
||||
|
||||
print(f"{steps}: vae loss: {logs['loss']} - discr loss: {logs['discr_loss']}")
|
||||
|
||||
# update exponential moving averaged generator
|
||||
|
||||
self.ema_vae.update()
|
||||
|
||||
# sample results every so often
|
||||
|
||||
if not (steps % self.save_results_every):
|
||||
for model, filename in ((self.ema_vae.ema_model, f'{steps}.ema'), (self.vae, str(steps))):
|
||||
model.eval()
|
||||
|
||||
imgs = next(self.dl)
|
||||
imgs = imgs.to(device)
|
||||
|
||||
recons = model(imgs)
|
||||
nrows = int(sqrt(self.batch_size))
|
||||
|
||||
imgs_and_recons = torch.stack((imgs, recons), dim = 0)
|
||||
imgs_and_recons = rearrange(imgs_and_recons, 'r b ... -> (b r) ...')
|
||||
|
||||
imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0., 1.)
|
||||
grid = make_grid(imgs_and_recons, nrow = 2, normalize = True, value_range = (0, 1))
|
||||
|
||||
logs['reconstructions'] = grid
|
||||
|
||||
save_image(grid, str(self.results_folder / f'{filename}.png'))
|
||||
|
||||
print(f'{steps}: saving to {str(self.results_folder)}')
|
||||
|
||||
# save model every so often
|
||||
|
||||
if not (steps % self.save_model_every):
|
||||
state_dict = self.vae.state_dict()
|
||||
model_path = str(self.results_folder / f'vae.{steps}.pt')
|
||||
torch.save(state_dict, model_path)
|
||||
|
||||
ema_state_dict = self.ema_vae.state_dict()
|
||||
model_path = str(self.results_folder / f'vae.{steps}.ema.pt')
|
||||
torch.save(ema_state_dict, model_path)
|
||||
|
||||
print(f'{steps}: saving model to {str(self.results_folder)}')
|
||||
|
||||
self.steps += 1
|
||||
return logs
|
||||
|
||||
def train(self, log_fn = noop):
|
||||
device = next(self.vae.parameters()).device
|
||||
|
||||
while self.steps < self.num_train_steps:
|
||||
logs = self.train_step()
|
||||
log_fn(logs)
|
||||
|
||||
print('training complete')
|
||||
Reference in New Issue
Block a user