take care of mixed precision, and make gradient accumulation do-able externally

Phil Wang
2022-04-30 12:27:24 -07:00
parent 5fff22834e
commit a2ef69af66
3 changed files with 22 additions and 4 deletions

@@ -811,7 +811,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
 - [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
 - [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
-- [ ] take care of mixed precision as well as gradient accumulation within decoder trainer
+- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs

@@ -3,6 +3,7 @@ from functools import partial
 import torch
 from torch import nn
+from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch.dalle2_pytorch import Decoder
 from dalle2_pytorch.optimizer import get_optimizer
@@ -98,6 +99,7 @@ class DecoderTrainer(nn.Module):
         lr = 3e-4,
         wd = 1e-2,
         max_grad_norm = None,
+        amp = False,
         **kwargs
     ):
         super().__init__()
@@ -115,6 +117,8 @@ class DecoderTrainer(nn.Module):
         self.ema_unets = nn.ModuleList([])

+        self.amp = amp
+
         # be able to finely customize learning rate, weight decay
         # per unet
@@ -133,10 +137,19 @@ class DecoderTrainer(nn.Module):
             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))

+            scaler = GradScaler(enabled = amp)
+            setattr(self, f'scaler{ind}', scaler)
+
         # gradient clipping if needed

         self.max_grad_norm = max_grad_norm

+    def scale(self, loss, *, unet_number):
+        assert 1 <= unet_number <= self.num_unets
+        index = unet_number - 1
+        scaler = getattr(self, f'scaler{index}')
+        return scaler.scale(loss)
+
     def update(self, unet_number):
         assert 1 <= unet_number <= self.num_unets
         index = unet_number - 1
@@ -146,7 +159,10 @@ class DecoderTrainer(nn.Module):
             nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

         optimizer = getattr(self, f'optim{index}')
-        optimizer.step()
+        scaler = getattr(self, f'scaler{index}')
+
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()

         if self.use_ema:
@@ -154,4 +170,6 @@ class DecoderTrainer(nn.Module):
             ema_unet.update()

     def forward(self, x, *, unet_number, **kwargs):
-        return self.decoder(x, unet_number = unet_number, **kwargs)
+        with autocast(enabled = self.amp):
+            loss = self.decoder(x, unet_number = unet_number, **kwargs)
+        return self.scale(loss, unet_number = unet_number)
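
For reference, a rough sketch of how the updated DecoderTrainer is meant to be driven from a training loop, with mixed precision switched on via the new amp flag and gradient accumulation handled externally by the caller. The decoder, dataloader, and accumulation factor below are placeholders and not part of this commit; only the amp argument, the scaled loss returned by forward, and update(unet_number) come from the diff above.

from dalle2_pytorch.train import DecoderTrainer   # import path assumed, may differ by version

# `decoder` is a Decoder built as in the project README (construction elided here),
# `dataloader` yields batches of images -- both are placeholders for this sketch
trainer = DecoderTrainer(
    decoder,
    lr = 3e-4,
    wd = 1e-2,
    amp = True                # forward now runs under autocast and returns a GradScaler-scaled loss
)

grad_accum_every = 4          # hypothetical accumulation factor, chosen by the caller

for step, images in enumerate(dataloader):
    loss = trainer(images, unet_number = 1)      # scaled loss for the first unet
    (loss / grad_accum_every).backward()         # gradients accumulate across calls

    if (step + 1) % grad_accum_every == 0:
        trainer.update(unet_number = 1)          # scaler.step + scaler.update + zero_grad (and EMA if enabled)

One GradScaler is kept per unet (scaler{ind}), so each unet's optimizer tracks its own loss scale when the unets are trained separately.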

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.78',
+  version = '0.0.79',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',