Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 17:54:20 +01:00)
make sure vqgan-vae trainer supports mixed precision
@@ -3,14 +3,15 @@ import copy
 from random import choice
 from pathlib import Path
 from shutil import rmtree
+from PIL import Image

 import torch
 from torch import nn
+from torch.cuda.amp import autocast, GradScaler
-from PIL import Image
-from torchvision.datasets import ImageFolder
-import torchvision.transforms as T
 from torch.utils.data import Dataset, DataLoader, random_split

+import torchvision.transforms as T
+from torchvision.datasets import ImageFolder
 from torchvision.utils import make_grid, save_image

 from einops import rearrange
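The only functional change among the imports is the new `torch.cuda.amp` line; the rest is reordering. As a minimal, self-contained sketch of the pattern those two names enable (the model, optimizer, and data below are illustrative stand-ins, not repo code):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(10, 1).cuda()    # stand-in model, for illustration only
optim = torch.optim.Adam(model.parameters())
scaler = GradScaler()                    # enabled = True by default

x = torch.randn(8, 10).cuda()

with autocast():                         # forward pass runs in mixed precision
    loss = model(x).sum()

scaler.scale(loss).backward()            # scale the loss to avoid fp16 gradient underflow
scaler.step(optim)                       # unscales gradients, then takes the optimizer step
scaler.update()                          # adjusts the scale factor for the next iteration
optim.zero_grad()
```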
@@ -99,6 +100,7 @@ class VQGanVAETrainer(nn.Module):
         ema_update_after_step = 2000,
         ema_update_every = 10,
         apply_grad_penalty_every = 4,
+        amp = False
     ):
         super().__init__()
         assert isinstance(vae, VQGanVAE), 'vae must be instance of VQGanVAE'
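Defaulting `amp = False` keeps mixed precision opt-in, so existing training scripts are unaffected. A hypothetical construction call might look like the following; the module paths, `dim`, and the folder path are assumptions for illustration, not taken from this commit:

```python
from dalle2_pytorch import VQGanVAE
from dalle2_pytorch.vqgan_vae_trainer import VQGanVAETrainer  # assumed module path

vae = VQGanVAE(dim = 32, image_size = 256)   # hypothetical hyperparameters

trainer = VQGanVAETrainer(
    vae,
    folder = './path/to/images',             # placeholder dataset folder
    image_size = 256,
    amp = True                               # turn on autocast + GradScaler in the training step
)
```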
@@ -120,6 +122,10 @@ class VQGanVAETrainer(nn.Module):
         self.optim = get_optimizer(vae_parameters, lr = lr, wd = wd)
         self.discr_optim = get_optimizer(discr_parameters, lr = lr, wd = wd)

+        self.amp = amp
+        self.scaler = GradScaler(enabled = amp)
+        self.discr_scaler = GradScaler(enabled = amp)
+
         # create dataset

         self.ds = ImageDataset(folder, image_size = image_size)
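There is one scaler per optimizer: the generator and discriminator are updated through different optimizers, and a `GradScaler` keeps per-optimizer loss-scale state, so each gets its own instance. When `amp = False`, both scalers and `autocast` become no-ops, which is what lets the training step below keep a single code path. A small demonstration of the disabled-scaler passthrough (illustrative, runs on CPU):

```python
import torch
from torch.cuda.amp import GradScaler

scaler = GradScaler(enabled = False)   # disabled: every method is a passthrough / no-op

loss = torch.tensor(2.0, requires_grad = True)
assert scaler.scale(loss) is loss      # scale() returns the loss unmodified when disabled
```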
@@ -178,20 +184,22 @@ class VQGanVAETrainer(nn.Module):
             img = next(self.dl)
             img = img.to(device)

-            loss = self.vae(
-                img,
-                return_loss = True,
-                apply_grad_penalty = apply_grad_penalty
-            )
+            with autocast(enabled = self.amp):
+                loss = self.vae(
+                    img,
+                    return_loss = True,
+                    apply_grad_penalty = apply_grad_penalty
+                )
+
+                self.scaler.scale(loss / self.grad_accum_every).backward()

             accum_log(logs, {'loss': loss.item() / self.grad_accum_every})

-            (loss / self.grad_accum_every).backward()
-
-        self.optim.step()
+        self.scaler.step(self.optim)
+        self.scaler.update()
         self.optim.zero_grad()

         # update discriminator

         if exists(self.vae.discr):
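The ordering here matters under gradient accumulation: each of the `grad_accum_every` micro-batches backpropagates `scaler.scale(loss / grad_accum_every)`, so gradients accumulate at a consistent loss scale, and `scaler.step` / `scaler.update` run exactly once per optimizer step, after the loop. A stripped-down sketch of that pattern (every name below is an illustrative stand-in, not repo code):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

model = torch.nn.Linear(16, 1).cuda()            # stand-in for the VAE
optim = torch.optim.Adam(model.parameters())
scaler = GradScaler()
grad_accum_every = 4

for _ in range(grad_accum_every):
    img = torch.randn(2, 16).cuda()              # stand-in for next(self.dl)
    with autocast():
        loss = model(img).sum()
    # dividing by grad_accum_every makes the accumulated gradient an average
    scaler.scale(loss / grad_accum_every).backward()

scaler.step(optim)                               # unscales the accumulated grads, then steps
scaler.update()                                  # adjusts the loss scale once per step
optim.zero_grad()
```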
@@ -200,12 +208,15 @@ class VQGanVAETrainer(nn.Module):
                 img = next(self.dl)
                 img = img.to(device)

-                loss = self.vae(img, return_discr_loss = True)
+                with autocast(enabled = self.amp):
+                    loss = self.vae(img, return_discr_loss = True)
+
+                    self.discr_scaler.scale(loss / self.grad_accum_every).backward()

                 accum_log(logs, {'discr_loss': loss.item() / self.grad_accum_every})

-                (loss / self.grad_accum_every).backward()
-
-            self.discr_optim.step()
+            self.discr_scaler.step(self.discr_optim)
+            self.discr_scaler.update()
             self.discr_optim.zero_grad()

             # log
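The discriminator update mirrors the generator's, but drives `discr_optim` through its own `discr_scaler`. Keeping separate scalers lets each loss adapt its scale factor independently, which matters in GAN training where generator and discriminator losses can have very different magnitudes. A minimal two-scaler skeleton of the idea (all modules and losses are illustrative stand-ins, not the repo's models):

```python
import torch
from torch.cuda.amp import autocast, GradScaler

gen = torch.nn.Linear(8, 8).cuda()               # stand-in generator
discr = torch.nn.Linear(8, 1).cuda()             # stand-in discriminator
gen_optim = torch.optim.Adam(gen.parameters())
discr_optim = torch.optim.Adam(discr.parameters())
gen_scaler, discr_scaler = GradScaler(), GradScaler()

x = torch.randn(4, 8).cuda()

# generator step, driven by its own scaler
with autocast():
    g_loss = -discr(gen(x)).mean()
gen_scaler.scale(g_loss).backward()
gen_scaler.step(gen_optim)
gen_scaler.update()
gen_optim.zero_grad()
discr_optim.zero_grad()                          # drop discriminator grads from the generator pass

# discriminator step, driven by the second scaler
with autocast():
    d_loss = discr(gen(x).detach()).mean()
discr_scaler.scale(d_loss).backward()
discr_scaler.step(discr_optim)
discr_scaler.update()
discr_optim.zero_grad()
```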