From a2ef69af661f4a7a398e3fee1449817b27b5e142 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sat, 30 Apr 2022 12:27:24 -0700
Subject: [PATCH] take care of mixed precision, and make gradient accumulation
 do-able externally

---
 README.md               |  2 +-
 dalle2_pytorch/train.py | 22 ++++++++++++++++++++--
 setup.py                |  2 +-
 3 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index ba5cfe5..a7686d6 100644
--- a/README.md
+++ b/README.md
@@ -811,7 +811,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
 - [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
 - [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
-- [ ] take care of mixed precision as well as gradient accumulation within decoder trainer
+- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
diff --git a/dalle2_pytorch/train.py b/dalle2_pytorch/train.py
index 4316204..a8c628c 100644
--- a/dalle2_pytorch/train.py
+++ b/dalle2_pytorch/train.py
@@ -3,6 +3,7 @@ from functools import partial
 
 import torch
 from torch import nn
+from torch.cuda.amp import autocast, GradScaler
 
 from dalle2_pytorch.dalle2_pytorch import Decoder
 from dalle2_pytorch.optimizer import get_optimizer
@@ -98,6 +99,7 @@ class DecoderTrainer(nn.Module):
         lr = 3e-4,
         wd = 1e-2,
         max_grad_norm = None,
+        amp = False,
         **kwargs
     ):
         super().__init__()
@@ -115,6 +117,8 @@ class DecoderTrainer(nn.Module):
 
         self.ema_unets = nn.ModuleList([])
 
+        self.amp = amp
+
         # be able to finely customize learning rate, weight decay
         # per unet
 
@@ -133,10 +137,19 @@ class DecoderTrainer(nn.Module):
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))
 
+            scaler = GradScaler(enabled = amp)
+            setattr(self, f'scaler{ind}', scaler)
+
         # gradient clipping if needed
 
         self.max_grad_norm = max_grad_norm
 
+    def scale(self, loss, *, unet_number):
+        assert 1 <= unet_number <= self.num_unets
+        index = unet_number - 1
+        scaler = getattr(self, f'scaler{index}')
+        return scaler.scale(loss)
+
     def update(self, unet_number):
         assert 1 <= unet_number <= self.num_unets
         index = unet_number - 1
@@ -146,7 +159,10 @@ class DecoderTrainer(nn.Module):
             nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
 
         optimizer = getattr(self, f'optim{index}')
-        optimizer.step()
+        scaler = getattr(self, f'scaler{index}')
+
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()
 
         if self.use_ema:
@@ -154,4 +170,6 @@ class DecoderTrainer(nn.Module):
             ema_unet.update()
 
     def forward(self, x, *, unet_number, **kwargs):
-        return self.decoder(x, unet_number = unet_number, **kwargs)
+        with autocast(enabled = self.amp):
+            loss = self.decoder(x, unet_number = unet_number, **kwargs)
+        return self.scale(loss, unet_number = unet_number)
diff --git a/setup.py b/setup.py
index 7a3bfbe..fe6a849 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.78',
+  version = '0.0.79',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',
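
For reference, a minimal sketch of how the reworked trainer might be driven from an external loop, which is where gradient accumulation now lives. `decoder`, `dataloader`, and `grad_accum_steps` are hypothetical placeholders, and the construction assumes the trainer takes the decoder as its first argument alongside the keyword arguments shown in the hunks above. `forward` runs under autocast and returns the scaler-scaled loss, so the caller owns `backward()` and decides how many batches to accumulate before calling `update`.

```python
from dalle2_pytorch.train import DecoderTrainer

# `decoder` is assumed to be an already built dalle2_pytorch Decoder and
# `dataloader` an iterable of image batches - both are placeholders here

trainer = DecoderTrainer(
    decoder,
    lr = 3e-4,
    amp = True               # turn on autocast + the per-unet GradScaler added in this patch
)

grad_accum_steps = 4         # hypothetical accumulation factor, chosen entirely by the caller

for step, images in enumerate(dataloader):
    # forward pass under autocast; the returned loss is already scaled by the GradScaler
    loss = trainer(images, unet_number = 1)

    # average over the accumulation window, then accumulate gradients
    (loss / grad_accum_steps).backward()

    # step the optimizer through the scaler only every `grad_accum_steps` batches
    if (step + 1) % grad_accum_steps == 0:
        trainer.update(unet_number = 1)
```

Because `update` is what calls `scaler.step` / `scaler.update` and zeroes the gradients, the accumulation factor stays a purely caller-side decision, which is what the commit title means by "do-able externally".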