Compare commits

...

3 Commits

3 changed files with 53 additions and 6 deletions

View File

@@ -760,7 +760,7 @@ decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1,
    timesteps = 1000,
    condition_on_text_encodings = True
).cuda()
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
    loss.backward()
    decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
# after much training
# you can sample from the exponentially moving averaged unets as so
mock_image_embed = torch.randn(4, 512).cuda()
images = decoder_trainer.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
```
## CLI (wip)
@@ -811,7 +817,7 @@ Once built, images will be saved to the same directory the command is invoked
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
- [ ] take care of mixed precision as well as gradient accumulation within decoder trainer
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs

View File

@@ -3,6 +3,7 @@ from functools import partial
import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from dalle2_pytorch.dalle2_pytorch import Decoder
from dalle2_pytorch.optimizer import get_optimizer
@@ -98,6 +99,7 @@ class DecoderTrainer(nn.Module):
        lr = 3e-4,
        wd = 1e-2,
        max_grad_norm = None,
        amp = False,
        **kwargs
    ):
        super().__init__()
@@ -115,6 +117,8 @@ class DecoderTrainer(nn.Module):
        self.ema_unets = nn.ModuleList([])

        self.amp = amp

        # be able to finely customize learning rate, weight decay
        # per unet
@@ -133,10 +137,23 @@ class DecoderTrainer(nn.Module):
            if self.use_ema:
                self.ema_unets.append(EMA(unet, **ema_kwargs))

            scaler = GradScaler(enabled = amp)
            setattr(self, f'scaler{ind}', scaler)

        # gradient clipping if needed
        self.max_grad_norm = max_grad_norm

    @property
    def unets(self):
        return nn.ModuleList([ema.ema_model for ema in self.ema_unets])

    def scale(self, loss, *, unet_number):
        assert 1 <= unet_number <= self.num_unets
        index = unet_number - 1
        scaler = getattr(self, f'scaler{index}')
        return scaler.scale(loss)

    def update(self, unet_number):
        assert 1 <= unet_number <= self.num_unets
        index = unet_number - 1
@@ -146,12 +163,36 @@ class DecoderTrainer(nn.Module):
            nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)

        optimizer = getattr(self, f'optim{index}')
        optimizer.step()
        scaler = getattr(self, f'scaler{index}')
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        if self.use_ema:
            ema_unet = self.ema_unets[index]
            ema_unet.update()

    def forward(self, x, *, unet_number, **kwargs):
        return self.decoder(x, unet_number = unet_number, **kwargs)

    @torch.no_grad()
    def sample(self, *args, **kwargs):
        if self.use_ema:
            trainable_unets = self.decoder.unets
            self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling

        output = self.decoder.sample(*args, **kwargs)

        if self.use_ema:
            self.decoder.unets = trainable_unets # restore original training unets

        return output

    def forward(
        self,
        x,
        *,
        unet_number,
        divisor = 1,
        **kwargs
    ):
        with autocast(enabled = self.amp):
            loss = self.decoder(x, unet_number = unet_number, **kwargs)

        return self.scale(loss / divisor, unet_number = unet_number)
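
As a usage note, here is a minimal sketch of how the mixed-precision and gradient-accumulation options added in this trainer might be exercised, assuming the constructor and `forward` signature shown above. The import path, the `accum_steps` value, and the reuse of the mock `images` / `text` tensors from the README example are illustrative assumptions, not part of this changeset:

```python
from dalle2_pytorch.train import DecoderTrainer  # import path assumed; adjust to wherever DecoderTrainer lives in this version

# `decoder`, `images`, `text` are assumed to be constructed as in the README example above

decoder_trainer = DecoderTrainer(
    decoder,
    lr = 3e-4,
    wd = 1e-2,
    amp = True               # new flag: run each unet's forward pass under autocast, with a per-unet GradScaler
)

accum_steps = 4              # illustrative number of micro-batches to accumulate per optimizer step

for unet_number in (1, 2):
    for _ in range(accum_steps):
        # in real training each iteration would use a fresh micro-batch;
        # forward divides the loss by `divisor` and returns it already scaled by the unet's GradScaler
        loss = decoder_trainer(
            images,
            text = text,
            unet_number = unet_number,
            divisor = accum_steps
        )
        loss.backward()

    # update() steps the optimizer through the GradScaler, updates the scaler,
    # zeroes the gradients, and refreshes the unet's exponential moving average
    decoder_trainer.update(unet_number)
```

Keeping one GradScaler per unet, alongside the existing per-unet optimizers, appears intended to keep mixed-precision state independent for each stage of the cascade, which is why the code registers `scaler{ind}` attributes rather than a single shared scaler.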

View File

@@ -10,7 +10,7 @@ setup(
      'dream = dalle2_pytorch.cli:dream'
    ],
  },
  version = '0.0.78',
  version = '0.0.81',
  license='MIT',
  description = 'DALL-E 2',
  author = 'Phil Wang',