From 3676ef4d4928de012a2a3f5d18dce5258c368661 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 6 May 2022 10:44:16 -0700
Subject: [PATCH] make sure vqgan-vae trainer supports mixed precision

---
 dalle2_pytorch/train_vqgan_vae.py | 45 +++++++++++++++++++------------
 setup.py                          |  2 +-
 2 files changed, 29 insertions(+), 18 deletions(-)

diff --git a/dalle2_pytorch/train_vqgan_vae.py b/dalle2_pytorch/train_vqgan_vae.py
index cab99c3..cbb6f1f 100644
--- a/dalle2_pytorch/train_vqgan_vae.py
+++ b/dalle2_pytorch/train_vqgan_vae.py
@@ -3,14 +3,15 @@ import copy
 from random import choice
 from pathlib import Path
 from shutil import rmtree
+from PIL import Image
 
 import torch
 from torch import nn
-
-from PIL import Image
-from torchvision.datasets import ImageFolder
-import torchvision.transforms as T
+from torch.cuda.amp import autocast, GradScaler
 from torch.utils.data import Dataset, DataLoader, random_split
+
+import torchvision.transforms as T
+from torchvision.datasets import ImageFolder
 from torchvision.utils import make_grid, save_image
 
 from einops import rearrange
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
         ema_update_after_step = 2000,
         ema_update_every = 10,
         apply_grad_penalty_every = 4,
+        amp = False
     ):
         super().__init__()
         assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
         self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
         self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)
 
+        self.amp = amp
+        self.scaler = GradScaler(enabled = amp)
+        self.discr_scaler = GradScaler(enabled = amp)
+
         # create dataset
 
         self.ds = ImageDataset(folder, image_size = image_size)
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
             img = next(self.dl)
             img = img.to(device)
 
-            loss = self.vae(
-                img,
-                return_loss = True,
-                apply_grad_penalty = apply_grad_penalty
-            )
+            with autocast(enabled = self.amp):
+                loss = self.vae(
+                    img,
+                    return_loss = True,
+                    apply_grad_penalty = apply_grad_penalty
+                )
+
+
+            self.scaler.scale(loss / self.grad_accum_every).backward()
 
             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})
 
-            (loss / self.grad_accum_every).backward()
-
-        self.optim.step()
+        self.scaler.step(self.optim)
+        self.scaler.update()
         self.optim.zero_grad()
 
-
         # update discriminator
 
         if exists(self.vae.discr):
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
                 img = next(self.dl)
                 img = img.to(device)
 
-                loss = self.vae(img, return_discr_loss = True)
+                with autocast(enabled = self.amp):
+                    loss = self.vae(img, return_discr_loss = True)
+
+                self.discr_scaler.scale(loss / self.grad_accum_every).backward()
+                accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})
 
 
-                (loss / self.grad_accum_every).backward()
-
-            self.discr_optim.step()
+            self.discr_scaler.step(self.discr_optim)
+            self.discr_scaler.update()
             self.discr_optim.zero_grad()
 
         # log
diff --git a/setup.py b/setup.py
index c981089..3613174 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.1.4',
+  version = '0.1.5',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
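
Note (appended for context, not part of the patch): the hunks above switch the trainer to the standard torch.cuda.amp recipe -- forward passes wrapped in autocast, the accumulated loss scaled through a GradScaler before backward, and optimizer steps routed through scaler.step / scaler.update. The sketch below illustrates that recipe in isolation; the linear model, random data, and hyperparameters are placeholders and are not taken from this repository.

    # Minimal sketch of the autocast + GradScaler pattern used in the patch.
    # The linear model and random tensors stand in for the VQGanVAE and its dataloader.
    import torch
    from torch import nn
    from torch.cuda.amp import autocast, GradScaler

    amp = torch.cuda.is_available()               # autocast/GradScaler become no-ops when disabled
    device = 'cuda' if amp else 'cpu'

    model = nn.Linear(16, 1).to(device)
    optim = torch.optim.Adam(model.parameters(), lr = 3e-4)
    scaler = GradScaler(enabled = amp)
    grad_accum_every = 4                          # mirrors self.grad_accum_every in the trainer

    for _ in range(grad_accum_every):
        x = torch.randn(8, 16, device = device)

        with autocast(enabled = amp):             # forward pass runs in mixed precision
            loss = model(x).pow(2).mean()

        # scale the averaged loss so fp16 gradients do not underflow
        scaler.scale(loss / grad_accum_every).backward()

    scaler.step(optim)                            # unscales grads; skips the step on inf/nan
    scaler.update()
    optim.zero_grad()

As in the diff, scaler.step and scaler.update run once per accumulation window rather than once per micro-batch, and the trainer keeps a second GradScaler for the discriminator so its gradient scale is tracked independently.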