From 740d644050687b1b09f2e22fce056b8ad3652bd5 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Fri, 6 May 2022 08:06:28 -0700
Subject: [PATCH] add diffusion prior trainer, which automatically takes care
 of the exponential moving average (training and sampling), as well as mixed
 precision, gradient clipping

---
 README.md                        | 62 +++++++++++++++++++++++
 dalle2_pytorch/__init__.py       |  2 +-
 dalle2_pytorch/dalle2_pytorch.py | 12 +++++
 dalle2_pytorch/train.py          | 85 +++++++++++++++++++++++++++++++-
 setup.py                         |  2 +-
 5 files changed, 159 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 83a7f31..090a255 100644
--- a/README.md
+++ b/README.md
@@ -786,6 +786,68 @@ mock_image_embed = torch.randn(4, 512).cuda()
 images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
 ```
 
+### Diffusion Prior Training
+
+Similarly, one can use the `DiffusionPriorTrainer` to automatically instantiate and keep track of an exponential moving average (EMA) copy of the prior.
+
+```python
+import torch
+from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, CLIP
+
+clip = CLIP(
+    dim_text = 512,
+    dim_image = 512,
+    dim_latent = 512,
+    num_text_tokens = 49408,
+    text_enc_depth = 6,
+    text_seq_len = 256,
+    text_heads = 8,
+    visual_enc_depth = 6,
+    visual_image_size = 256,
+    visual_patch_size = 32,
+    visual_heads = 8
+).cuda()
+
+# mock data
+
+text = torch.randint(0, 49408, (4, 256)).cuda()
+images = torch.randn(4, 3, 256, 256).cuda()
+
+# prior networks (with transformer)
+
+prior_network = DiffusionPriorNetwork(
+    dim = 512,
+    depth = 6,
+    dim_head = 64,
+    heads = 8
+).cuda()
+
+diffusion_prior = DiffusionPrior(
+    net = prior_network,
+    clip = clip,
+    timesteps = 100,
+    cond_drop_prob = 0.2
+).cuda()
+
+diffusion_prior_trainer = DiffusionPriorTrainer(
+    diffusion_prior,
+    lr = 3e-4,
+    wd = 1e-2,
+    ema_beta = 0.99,
+    ema_update_after_step = 1000,
+    ema_update_every = 10,
+)
+
+loss = diffusion_prior_trainer(text, images)
+loss.backward()
+diffusion_prior_trainer.update() # this will update the optimizer as well as the EMA copy of the diffusion prior
+
+# after running the above three lines in a loop over your data,
+# you can sample from the EMA of the diffusion prior just as you would from DiffusionPrior
+
+image_embeds = diffusion_prior_trainer.sample(text) # (4, 512) - image embeddings sampled from the EMA diffusion prior
+```
+
 ### Decoder Dataloaders
 
 In order to make loading data simple and efficient, we include some general dataloaders that can be used to train portions of the network.
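For context, a minimal sketch of the training loop the README comment above alludes to; `num_epochs` and `train_dataloader` are hypothetical stand-ins, not part of this patch:

```python
# hypothetical training loop; assumes `train_dataloader` yields (text, images)
# batches shaped like the mock tensors in the README example above
for epoch in range(num_epochs):
    for text, images in train_dataloader:
        loss = diffusion_prior_trainer(text, images)
        loss.backward()
        diffusion_prior_trainer.update()  # optimizer step, then EMA update

# sampling afterwards draws from the EMA copy of the prior
image_embeds = diffusion_prior_trainer.sample(text)  # (4, 512)
```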
diff --git a/dalle2_pytorch/__init__.py b/dalle2_pytorch/__init__.py
index c3ab2d9..60987bd 100644
--- a/dalle2_pytorch/__init__.py
+++ b/dalle2_pytorch/__init__.py
@@ -1,6 +1,6 @@
 from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
 from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
-from dalle2_pytorch.train import DecoderTrainer
+from dalle2_pytorch.train import DecoderTrainer, DiffusionPriorTrainer
 from dalle2_pytorch.vqgan_vae import VQGanVAE
 
 from x_clip import CLIP
diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 4c9bc0c..197c855 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -845,6 +845,18 @@ class DiffusionPrior(BaseGaussianDiffusion):
         loss = self.loss_fn(pred, target)
         return loss
 
+    @torch.inference_mode()
+    @eval_decorator
+    def sample_batch_size(self, batch_size, text_cond):
+        device = self.betas.device
+        shape = (batch_size, self.image_embed_dim)
+
+        img = torch.randn(shape, device = device)
+
+        for i in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
+            img = self.p_sample(img, torch.full((batch_size,), i, device = device, dtype = torch.long), text_cond = text_cond)
+        return img
+
     @torch.inference_mode()
     @eval_decorator
     def sample(self, text, num_samples_per_batch = 2):
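A note on the new `sample_batch_size` method: unlike `sample`, which takes tokenized text and builds its own conditioning, it runs the bare denoising loop from pure noise for a given batch size and a prebuilt `text_cond` dict. A minimal sketch of a call, reusing the prior from the README example; the `text_embed` key and the mock embedding are illustrative assumptions, not part of this patch:

```python
import torch

# mock text embedding standing in for a real CLIP text embedding (assumption)
mock_text_embed = torch.randn(4, 512).cuda()

# draw 4 image embeddings from pure noise, conditioned on the text embedding
image_embeds = diffusion_prior.sample_batch_size(
    4,
    text_cond = dict(text_embed = mock_text_embed)
)  # (4, 512)
```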
diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index ddb0732..a346f6b 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 from torch.cuda.amp import autocast, GradScaler
 
-from dalle2_pytorch.dalle2_pytorch import Decoder
+from dalle2_pytorch.dalle2_pytorch import Decoder, DiffusionPrior
 from dalle2_pytorch.optimizer import get_optimizer
 
 # helper functions
@@ -89,7 +89,88 @@ class EMA(nn.Module):
     def __call__(self, *args, **kwargs):
         return self.ema_model(*args, **kwargs)
 
-# trainers
+# diffusion prior trainer
+
+class DiffusionPriorTrainer(nn.Module):
+    def __init__(
+        self,
+        diffusion_prior,
+        use_ema = True,
+        lr = 3e-4,
+        wd = 1e-2,
+        max_grad_norm = None,
+        amp = False,
+        **kwargs
+    ):
+        super().__init__()
+        assert isinstance(diffusion_prior, DiffusionPrior)
+        ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
+
+        self.diffusion_prior = diffusion_prior
+
+        # exponential moving average
+
+        self.use_ema = use_ema
+
+        if use_ema:
+            has_lazy_linear = any(isinstance(module, nn.LazyLinear) for module in diffusion_prior.modules())
+            assert not has_lazy_linear, 'you must not have any LazyLinear modules in your diffusion prior if you plan on doing automatic exponential moving average'
+
+        if self.use_ema:
+            self.ema_diffusion_prior = EMA(diffusion_prior, **ema_kwargs)
+
+        # optimizer and mixed precision stuff
+
+        self.amp = amp
+
+        self.scaler = GradScaler(enabled = amp)
+
+        self.optimizer = get_optimizer(
+            diffusion_prior.parameters(),
+            lr = lr,
+            wd = wd,
+            **kwargs
+        )
+
+        # gradient clipping if needed
+
+        self.max_grad_norm = max_grad_norm
+
+    def update(self):
+        if exists(self.max_grad_norm):
+            self.scaler.unscale_(self.optimizer)
+            nn.utils.clip_grad_norm_(self.diffusion_prior.parameters(), self.max_grad_norm)
+
+        self.scaler.step(self.optimizer)
+        self.scaler.update()
+        self.optimizer.zero_grad()
+
+        if self.use_ema:
+            self.ema_diffusion_prior.update()
+
+    @torch.inference_mode()
+    def p_sample_loop(self, *args, **kwargs):
+        return self.ema_diffusion_prior.ema_model.p_sample_loop(*args, **kwargs)
+
+    @torch.inference_mode()
+    def sample(self, *args, **kwargs):
+        return self.ema_diffusion_prior.ema_model.sample(*args, **kwargs)
+
+    @torch.inference_mode()
+    def sample_batch_size(self, *args, **kwargs):
+        return self.ema_diffusion_prior.ema_model.sample_batch_size(*args, **kwargs)
+
+    def forward(
+        self,
+        *args,
+        divisor = 1,
+        **kwargs
+    ):
+        with autocast(enabled = self.amp):
+            loss = self.diffusion_prior(*args, **kwargs)
+        return self.scaler.scale(loss / divisor)
+
+# decoder trainer
 
 class DecoderTrainer(nn.Module):
     def __init__(
diff --git a/setup.py b/setup.py
index 7246ad2..a22aec5 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.107',
+  version = '0.0.108',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
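One detail worth calling out from the trainer's `forward`: the loss is divided by `divisor` before being scaled, which makes gradient accumulation straightforward. A minimal sketch, assuming a hypothetical `train_dataloader` and an accumulation window of 4 micro-batches:

```python
accum_steps = 4  # assumed number of micro-batches per optimizer step

for step, (text, images) in enumerate(train_dataloader):  # hypothetical dataloader
    # forward() computes loss / divisor under autocast and scales it, so gradients
    # accumulated over `accum_steps` backward passes average correctly
    loss = diffusion_prior_trainer(text, images, divisor = accum_steps)
    loss.backward()

    if (step + 1) % accum_steps == 0:
        diffusion_prior_trainer.update()  # clip (if configured), optimizer step, scaler and EMA update
```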