add diffusion prior trainer, which automatically takes care of the exponential moving average (training and sampling), as well as mixed precision, gradient clipping

This commit is contained in:
Phil Wang
2022-05-06 08:11:09 -07:00
parent 878b555ef7
commit 98df1ba51e
5 changed files with 154 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
from dalle2_pytorch.train import DecoderTrainer
from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
from dalle2_pytorch.vqgan_vae import VQGanVAE
from x_clip import CLIP

View File

@@ -845,6 +845,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
loss = self.loss_fn(pred, target)
return loss
@torch.inference_mode()
@eval_decorator
def sample_batch_size(self, batch_size, text_cond):
device = self.betas.device
shape = (batch_size, self.image_embed_dim)
img = torch.randn(shape, device = device)
for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
return img
@torch.inference_mode()
@eval_decorator
def sample(self, text, num_samples_per_batch = 2):

View File

@@ -5,7 +5,7 @@ import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder
from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
from dalle2_pytorch.optimizer import get_optimizer
# helper functions
@@ -89,7 +89,83 @@ class EMA(nn.Module):
def __call__(self, *args, **kwargs):
return self.ema_model(*args, **kwargs)
# trainers
# diffusion prior trainer
class DiffusionPriorTrainer(nn.Module):
def __init__(
self,
diffusion_prior,
use_ema = True,
lr = 3e-4,
wd = 1e-2,
max_grad_norm = None,
amp = False,
**kwargs
):
super().__init__()
assert isinstance(diffusion_prior, DiffusionPrior)
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
self.diffusion_prior = diffusion_prior
# exponential moving average
self.use_ema = use_ema
if self.use_ema:
self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
# optimizer and mixed precision stuff
self.amp = amp
self.scaler = GradScaler(enabled = amp)
self.optimizer = get_optimizer(
diffusion_prior.parameters(),
lr = lr,
wd = wd,
**kwargs
)
# gradient clipping if needed
self.max_grad_norm = max_grad_norm
def update(self):
if exists(self.max_grad_norm):
self.scaler.unscale_(self.optimizer)
nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
if self.use_ema:
self.ema_diffusion_prior.update()
@torch.inference_mode()
def p_sample_loop(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
@torch.inference_mode()
def sample(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
@torch.inference_mode()
def sample_batch_size(self, *args, **kwargs):
return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
def forward(
self,
*args,
divisor = 1,
**kwargs
):
with autocast(enabled = self.amp):
loss = self.diffusion_prior(*args, **kwargs)
return self.scaler.scale(loss / divisor)
# decoder trainer
class DecoderTrainer(nn.Module):
def __init__(