make sure vqgan-vae trainer supports mixed precision

This commit is contained in:
Phil Wang
2022-05-06 10:44:16 -07:00
parent 28e944f328
commit 3676ef4d49
2 changed files with 29 additions and 18 deletions

View File

@@ -3,14 +3,15 @@ import copy
from random import choice from random import choice
from pathlib import Path from pathlib import Path
from shutil import rmtree from shutil import rmtree
from PIL import Image
import torch import torch
from torch import nn from torch import nn
from torch.cuda.amp import autocast, GradScaler
from PIL import Image
from torchvision.datasets import ImageFolder
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid, save_image from torchvision.utils import make_grid, save_image
from einops import rearrange from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
ema_update_after_step = 2000, ema_update_after_step = 2000,
ema_update_every = 10, ema_update_every = 10,
apply_grad_penalty_every = 4, apply_grad_penalty_every = 4,
amp = False
): ):
super().__init__() super().__init__()
assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE' assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd) self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd) self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.discr_scaler = GradScaler(enabled = amp)
# create dataset # create dataset
self.ds = ImageDataset(folder, image_size = image_size) self.ds = ImageDataset(folder, image_size = image_size)
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl) img = next(self.dl)
img = img.to(device) img = img.to(device)
with autocast(enabled = self.amp):
loss = self.vae( loss = self.vae(
img, img,
return_loss = True, return_loss = True,
apply_grad_penalty = apply_grad_penalty apply_grad_penalty = apply_grad_penalty
) )
self.scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'loss': loss.item() / self.grad_accum_every}) accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward() self.scaler.step(self.optim)
self.scaler.update()
self.optim.step()
self.optim.zero_grad() self.optim.zero_grad()
# update discriminator # update discriminator
if exists(self.vae.discr): if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
img = next(self.dl) img = next(self.dl)
img = img.to(device) img = img.to(device)
with autocast(enabled = self.amp):
loss = self.vae(img, return_discr_loss = True) loss = self.vae(img, return_discr_loss = True)
self.discr_scaler.scale(loss / self.grad_accum_every).backward()
accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every}) accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
(loss / self.grad_accum_every).backward() self.discr_scaler.step(self.discr_optim)
self.discr_scaler.update()
self.discr_optim.step()
self.discr_optim.zero_grad() self.discr_optim.zero_grad()
# log # log

View File

@@ -10,7 +10,7 @@ setup(
'dream = dalle2_pytorch.cli:dream' 'dream = dalle2_pytorch.cli:dream'
], ],
}, },
version = '0.1.4', version = '0.1.5',
license='MIT', license='MIT',
description = 'DALL-E 2', description = 'DALL-E 2',
author = 'Phil Wang', author = 'Phil Wang',