Compare commits


2 Commits

3 changed files with 52 additions and 6 deletions

README.md

@@ -811,7 +811,7 @@ Once built, images will be saved to the same directory the command is invoked
 - [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
 - [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
 - [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
-- [ ] take care of mixed precision as well as gradient accumulation within decoder trainer
+- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
 - [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
 - [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
 - [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs

dalle2_pytorch/train.py

@@ -3,12 +3,19 @@ from functools import partial
 import torch
 from torch import nn
+from torch.cuda.amp import autocast, GradScaler

 from dalle2_pytorch.dalle2_pytorch import Decoder
 from dalle2_pytorch.optimizer import get_optimizer

 # helper functions

 def exists(val):
     return val is not None

+def cast_tuple(val, length = 1):
+    return val if isinstance(val, tuple) else ((val,) * length)
+
 def pick_and_pop(keys, d):
     values = list(map(lambda key: d.pop(key), keys))
     return dict(zip(keys, values))
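
For reference, the added cast_tuple helper broadcasts a scalar into a tuple of the requested length, while anything that is already a tuple passes through unchanged, regardless of its length. A quick sketch of the behavior:

    cast_tuple(3e-4, length = 3)           # (0.0003, 0.0003, 0.0003)
    cast_tuple((1e-4, 5e-5), length = 3)   # (0.0001, 5e-05) - existing tuples pass through as-is

This is what lets lr and wd in the trainer changes below be supplied either as one shared value or as one value per unet.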
@@ -89,6 +96,10 @@ class DecoderTrainer(nn.Module):
         self,
         decoder,
         use_ema = True,
+        lr = 3e-4,
+        wd = 1e-2,
+        max_grad_norm = None,
+        amp = False,
         **kwargs
     ):
         super().__init__()
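
Given the new keyword arguments above, construction might look like the following sketch; the decoder instance and the specific hyperparameter values are assumptions for illustration, not part of this diff:

    trainer = DecoderTrainer(
        decoder,                # assumed: a Decoder wrapping two unets
        lr = (3e-4, 1e-4),      # per-unet learning rates, or a single shared float
        wd = 1e-2,              # weight decay, broadcast to every unet
        max_grad_norm = 0.5,    # assumed value; turns on gradient clipping in update()
        amp = True              # enables autocast + GradScaler mixed precision
    )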
@@ -106,19 +117,52 @@ class DecoderTrainer(nn.Module):
         self.ema_unets = nn.ModuleList([])

-        for ind, unet in enumerate(self.decoder.unets):
-            optimizer = get_optimizer(unet.parameters(), **kwargs)
+        self.amp = amp
+
+        # be able to finely customize learning rate, weight decay
+        # per unet
+
+        lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
+
+        for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
+            optimizer = get_optimizer(
+                unet.parameters(),
+                lr = unet_lr,
+                wd = unet_wd,
+                **kwargs
+            )

             setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers

             if self.use_ema:
                 self.ema_unets.append(EMA(unet, **ema_kwargs))

+            scaler = GradScaler(enabled = amp)
+            setattr(self, f'scaler{ind}', scaler)
+
+        # gradient clipping if needed
+
+        self.max_grad_norm = max_grad_norm
+
+    def scale(self, loss, *, unet_number):
+        assert 1 <= unet_number <= self.num_unets
+        index = unet_number - 1
+        scaler = getattr(self, f'scaler{index}')
+        return scaler.scale(loss)
+
     def update(self, unet_number):
         assert 1 <= unet_number <= self.num_unets
         index = unet_number - 1

         unet = self.decoder.unets[index]
+
+        if exists(self.max_grad_norm):
+            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

         optimizer = getattr(self, f'optim{index}')
-        optimizer.step()
+        scaler = getattr(self, f'scaler{index}')
+
+        scaler.step(optimizer)
+        scaler.update()
         optimizer.zero_grad()

         if self.use_ema:
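
A note on the setattr registration seen above: nn.ModuleList only accepts nn.Module instances, and neither optimizers nor GradScaler objects are Modules, which is presumably the "some reason" in the inline comment. Stashing them as numbered attributes (optim0, scaler0, ...) keeps one per unet while letting scale() and update() recover the right pair by index via getattr. A minimal standalone sketch of that round-trip, with hypothetical names:

    from torch import nn
    from torch.cuda.amp import GradScaler

    class Holder(nn.Module):
        def __init__(self, num_unets):
            super().__init__()
            for ind in range(num_unets):
                # plain attribute assignment; GradScaler is not a Module
                setattr(self, f'scaler{ind}', GradScaler(enabled = False))

    holder = Holder(2)
    scaler = getattr(holder, 'scaler1')   # fetch per-unet state by index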
@@ -126,4 +170,6 @@ class DecoderTrainer(nn.Module):
             ema_unet.update()

     def forward(self, x, *, unet_number, **kwargs):
-        return self.decoder(x, unet_number = unet_number, **kwargs)
+        with autocast(enabled = self.amp):
+            loss = self.decoder(x, unet_number = unet_number, **kwargs)
+        return self.scale(loss, unet_number = unet_number)
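
Taken together, the pieces above give the full mixed-precision step: forward() computes the loss under autocast and returns it already scaled by the matching per-unet GradScaler, the caller backpropagates that scaled loss, and update() clips gradients if max_grad_norm is set, steps the optimizer through the scaler (which skips the step when inf/nan gradients are found and adjusts the scale), and zeroes gradients. A hypothetical training step, with images standing in for a batch tensor:

    for unet_number in (1, 2):                             # assumed: a two-unet decoder
        loss = trainer(images, unet_number = unet_number)  # scaled loss from forward()
        loss.backward()                                    # backprop through the scaled loss
        trainer.update(unet_number)                        # clip, scaler.step, scaler.update, zero_grad

One caveat worth noting: as written, clip_grad_norm_ in update() runs before scaler.step, i.e. on still-scaled gradients; calling scaler.unscale_(optimizer) beforehand would be the usual pattern if clipping at the true gradient scale is intended.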

setup.py

@@ -10,7 +10,7 @@ setup(
       'dream = dalle2_pytorch.cli:dream'
     ],
   },
-  version = '0.0.77',
+  version = '0.0.79',
   license='MIT',
   description = 'DALL-E 2',
   author = 'Phil Wang',