mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 15:44:20 +01:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8260fc933a | ||
|
|
ebe01749ed | ||
|
|
63195cc2cb | ||
|
|
a2ef69af66 |
10
README.md
10
README.md
@@ -760,7 +760,7 @@ decoder = Decoder(
|
|||||||
unet = (unet1, unet2),
|
unet = (unet1, unet2),
|
||||||
image_sizes = (128, 256),
|
image_sizes = (128, 256),
|
||||||
clip = clip,
|
clip = clip,
|
||||||
timesteps = 1,
|
timesteps = 1000,
|
||||||
condition_on_text_encodings = True
|
condition_on_text_encodings = True
|
||||||
).cuda()
|
).cuda()
|
||||||
|
|
||||||
@@ -778,6 +778,12 @@ for unet_number in (1, 2):
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||||
|
|
||||||
|
# after much training
|
||||||
|
# you can sample from the exponentially moving averaged unets as so
|
||||||
|
|
||||||
|
mock_image_embed = torch.randn(4, 512).cuda()
|
||||||
|
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||||
```
|
```
|
||||||
|
|
||||||
## CLI (wip)
|
## CLI (wip)
|
||||||
@@ -811,7 +817,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||||
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||||
- [ ] take care of mixed precision as well as gradient accumulation within decoder trainer
|
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||||
|
|||||||
@@ -1540,7 +1540,13 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@eval_decorator
|
@eval_decorator
|
||||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
def sample(
|
||||||
|
self,
|
||||||
|
image_embed,
|
||||||
|
text = None,
|
||||||
|
cond_scale = 1.,
|
||||||
|
stop_at_unet_number = None
|
||||||
|
):
|
||||||
batch_size = image_embed.shape[0]
|
batch_size = image_embed.shape[0]
|
||||||
|
|
||||||
text_encodings = text_mask = None
|
text_encodings = text_mask = None
|
||||||
@@ -1552,7 +1558,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
img = None
|
img = None
|
||||||
|
|
||||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||||
|
|
||||||
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||||
|
|
||||||
@@ -1584,6 +1590,9 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
img = vae.decode(img)
|
img = vae.decode(img)
|
||||||
|
|
||||||
|
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
|
||||||
|
break
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from functools import partial
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
|
|
||||||
from dalle2_pytorch.dalle2_pytorch import Decoder
|
from dalle2_pytorch.dalle2_pytorch import Decoder
|
||||||
from dalle2_pytorch.optimizer import get_optimizer
|
from dalle2_pytorch.optimizer import get_optimizer
|
||||||
@@ -98,6 +99,7 @@ class DecoderTrainer(nn.Module):
|
|||||||
lr = 3e-4,
|
lr = 3e-4,
|
||||||
wd = 1e-2,
|
wd = 1e-2,
|
||||||
max_grad_norm = None,
|
max_grad_norm = None,
|
||||||
|
amp = False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -115,6 +117,8 @@ class DecoderTrainer(nn.Module):
|
|||||||
|
|
||||||
self.ema_unets = nn.ModuleList([])
|
self.ema_unets = nn.ModuleList([])
|
||||||
|
|
||||||
|
self.amp = amp
|
||||||
|
|
||||||
# be able to finely customize learning rate, weight decay
|
# be able to finely customize learning rate, weight decay
|
||||||
# per unet
|
# per unet
|
||||||
|
|
||||||
@@ -133,10 +137,23 @@ class DecoderTrainer(nn.Module):
|
|||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||||
|
|
||||||
|
scaler = GradScaler(enabled = amp)
|
||||||
|
setattr(self, f'scaler{ind}', scaler)
|
||||||
|
|
||||||
# gradient clipping if needed
|
# gradient clipping if needed
|
||||||
|
|
||||||
self.max_grad_norm = max_grad_norm
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unets(self):
|
||||||
|
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||||
|
|
||||||
|
def scale(self, loss, *, unet_number):
|
||||||
|
assert 1 <= unet_number <= self.num_unets
|
||||||
|
index = unet_number - 1
|
||||||
|
scaler = getattr(self, f'scaler{index}')
|
||||||
|
return scaler.scale(loss)
|
||||||
|
|
||||||
def update(self, unet_number):
|
def update(self, unet_number):
|
||||||
assert 1 <= unet_number <= self.num_unets
|
assert 1 <= unet_number <= self.num_unets
|
||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
@@ -146,12 +163,36 @@ class DecoderTrainer(nn.Module):
|
|||||||
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
optimizer = getattr(self, f'optim{index}')
|
optimizer = getattr(self, f'optim{index}')
|
||||||
optimizer.step()
|
scaler = getattr(self, f'scaler{index}')
|
||||||
|
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
|
||||||
if self.use_ema:
|
if self.use_ema:
|
||||||
ema_unet = self.ema_unets[index]
|
ema_unet = self.ema_unets[index]
|
||||||
ema_unet.update()
|
ema_unet.update()
|
||||||
|
|
||||||
def forward(self, x, *, unet_number, **kwargs):
|
@torch.no_grad()
|
||||||
return self.decoder(x, unet_number = unet_number, **kwargs)
|
def sample(self, *args, **kwargs):
|
||||||
|
if self.use_ema:
|
||||||
|
trainable_unets = self.decoder.unets
|
||||||
|
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||||
|
|
||||||
|
output = self.decoder.sample(*args, **kwargs)
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
self.decoder.unets = trainable_unets # restore original training unets
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
*,
|
||||||
|
unet_number,
|
||||||
|
divisor = 1,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
with autocast(enabled = self.amp):
|
||||||
|
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||||
|
return self.scale(loss / divisor, unet_number = unet_number)
|
||||||
|
|||||||
Reference in New Issue
Block a user