mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2026-02-14 21:44:29 +01:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8260fc933a | ||
|
|
ebe01749ed | ||
|
|
63195cc2cb | ||
|
|
a2ef69af66 | ||
|
|
5fff22834e | ||
|
|
a9421f49ec | ||
|
|
77fa34eae9 | ||
|
|
1c1e508369 |
79
README.md
79
README.md
@@ -708,7 +708,83 @@ images = decoder.sample(mock_image_embed) # (1, 3, 1024, 1024)
|
|||||||
|
|
||||||
## Training wrapper (wip)
|
## Training wrapper (wip)
|
||||||
|
|
||||||
Offer training wrappers
|
### Decoder Training
|
||||||
|
|
||||||
|
Training the `Decoder` may be confusing, as one needs to keep track of an optimizer for each of the `Unet`(s) separately. Each `Unet` will also need its own corresponding exponential moving average. The `DecoderTrainer` hopes to make this simple, as shown below
|
||||||
|
|
||||||
|
```python
|
||||||
|
import torch
|
||||||
|
from dalle2_pytorch import DALLE2, Unet, Decoder, CLIP, DecoderTrainer
|
||||||
|
|
||||||
|
clip = CLIP(
|
||||||
|
dim_text = 512,
|
||||||
|
dim_image = 512,
|
||||||
|
dim_latent = 512,
|
||||||
|
num_text_tokens = 49408,
|
||||||
|
text_enc_depth = 6,
|
||||||
|
text_seq_len = 256,
|
||||||
|
text_heads = 8,
|
||||||
|
visual_enc_depth = 6,
|
||||||
|
visual_image_size = 256,
|
||||||
|
visual_patch_size = 32,
|
||||||
|
visual_heads = 8
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
# mock data
|
||||||
|
|
||||||
|
text = torch.randint(0, 49408, (4, 256)).cuda()
|
||||||
|
images = torch.randn(4, 3, 256, 256).cuda()
|
||||||
|
|
||||||
|
# decoder (with unet)
|
||||||
|
|
||||||
|
unet1 = Unet(
|
||||||
|
dim = 128,
|
||||||
|
image_embed_dim = 512,
|
||||||
|
text_embed_dim = 512,
|
||||||
|
cond_dim = 128,
|
||||||
|
channels = 3,
|
||||||
|
dim_mults=(1, 2, 4, 8)
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
unet2 = Unet(
|
||||||
|
dim = 16,
|
||||||
|
image_embed_dim = 512,
|
||||||
|
text_embed_dim = 512,
|
||||||
|
cond_dim = 128,
|
||||||
|
channels = 3,
|
||||||
|
dim_mults = (1, 2, 4, 8, 16),
|
||||||
|
cond_on_text_encodings = True
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
decoder = Decoder(
|
||||||
|
unet = (unet1, unet2),
|
||||||
|
image_sizes = (128, 256),
|
||||||
|
clip = clip,
|
||||||
|
timesteps = 1000,
|
||||||
|
condition_on_text_encodings = True
|
||||||
|
).cuda()
|
||||||
|
|
||||||
|
decoder_trainer = DecoderTrainer(
|
||||||
|
decoder,
|
||||||
|
lr = 3e-4,
|
||||||
|
wd = 1e-2,
|
||||||
|
ema_beta = 0.99,
|
||||||
|
ema_update_after_step = 1000,
|
||||||
|
ema_update_every = 10,
|
||||||
|
)
|
||||||
|
|
||||||
|
for unet_number in (1, 2):
|
||||||
|
loss = decoder_trainer(images, text = text, unet_number = unet_number) # use the decoder_trainer forward
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
decoder_trainer.update(unet_number) # update the specific unet as well as its exponential moving average
|
||||||
|
|
||||||
|
# after much training
|
||||||
|
# you can sample from the exponentially moving averaged unets as so
|
||||||
|
|
||||||
|
mock_image_embed = torch.randn(4, 512).cuda()
|
||||||
|
images = decoder.sample(mock_image_embed, text = text) # (4, 3, 256, 256)
|
||||||
|
```
|
||||||
|
|
||||||
## CLI (wip)
|
## CLI (wip)
|
||||||
|
|
||||||
@@ -741,6 +817,7 @@ Once built, images will be saved to the same directory the command is invoked
|
|||||||
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
- [x] use inheritance just this once for sharing logic between decoder and prior network ddpms
|
||||||
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
- [x] bring in vit-vqgan https://arxiv.org/abs/2110.04627 for the latent diffusion
|
||||||
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
- [x] abstract interface for CLIP adapter class, so other CLIPs can be brought in
|
||||||
|
- [x] take care of mixed precision as well as gradient accumulation within decoder trainer
|
||||||
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
- [ ] become an expert with unets, cleanup unet code, make it fully configurable, port all learnings over to https://github.com/lucidrains/x-unet
|
||||||
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
- [ ] copy the cascading ddpm code to a separate repo (perhaps https://github.com/lucidrains/denoising-diffusion-pytorch) as the main contribution of dalle2 really is just the prior network
|
||||||
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
- [ ] transcribe code to Jax, which lowers the activation energy for distributed training, given access to TPUs
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
from dalle2_pytorch.dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder
|
||||||
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
from dalle2_pytorch.dalle2_pytorch import OpenAIClipAdapter
|
||||||
|
from dalle2_pytorch.train import DecoderTrainer
|
||||||
|
|
||||||
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
from dalle2_pytorch.vqgan_vae import VQGanVAE
|
||||||
from x_clip import CLIP
|
from x_clip import CLIP
|
||||||
|
|||||||
@@ -736,6 +736,7 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
predict_x_start = True,
|
predict_x_start = True,
|
||||||
beta_schedule = "cosine",
|
beta_schedule = "cosine",
|
||||||
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
||||||
|
sampling_clamp_l2norm = False
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -764,6 +765,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
self.predict_x_start = predict_x_start
|
self.predict_x_start = predict_x_start
|
||||||
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
||||||
|
|
||||||
|
# whether to force an l2norm, similar to clipping denoised, when sampling
|
||||||
|
self.sampling_clamp_l2norm = sampling_clamp_l2norm
|
||||||
|
|
||||||
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
def p_mean_variance(self, x, t, text_cond, clip_denoised: bool):
|
||||||
pred = self.net(x, t, **text_cond)
|
pred = self.net(x, t, **text_cond)
|
||||||
|
|
||||||
@@ -777,6 +781,9 @@ class DiffusionPrior(BaseGaussianDiffusion):
|
|||||||
if clip_denoised and not self.predict_x_start:
|
if clip_denoised and not self.predict_x_start:
|
||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
|
|
||||||
|
if self.predict_x_start and self.sampling_clamp_l2norm:
|
||||||
|
x_recon = l2norm(x_recon)
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@@ -1090,7 +1097,12 @@ class Unet(nn.Module):
|
|||||||
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
Rearrange('b (n d) -> b n d', n = num_image_tokens)
|
||||||
) if image_embed_dim != cond_dim else nn.Identity()
|
) if image_embed_dim != cond_dim else nn.Identity()
|
||||||
|
|
||||||
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
|
# text encoding conditioning (optional)
|
||||||
|
|
||||||
|
self.text_to_cond = None
|
||||||
|
|
||||||
|
if cond_on_text_encodings:
|
||||||
|
self.text_to_cond = nn.LazyLinear(cond_dim) if not exists(text_embed_dim) else nn.Linear(text_embed_dim, cond_dim)
|
||||||
|
|
||||||
# finer control over whether to condition on image embeddings and text encodings
|
# finer control over whether to condition on image embeddings and text encodings
|
||||||
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
# so one can have the latter unets in the cascading DDPMs only focus on super-resoluting
|
||||||
@@ -1101,6 +1113,8 @@ class Unet(nn.Module):
|
|||||||
# for classifier free guidance
|
# for classifier free guidance
|
||||||
|
|
||||||
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
self.null_image_embed = nn.Parameter(torch.randn(1, num_image_tokens, cond_dim))
|
||||||
|
|
||||||
|
self.max_text_len = max_text_len
|
||||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
|
||||||
|
|
||||||
# attention related params
|
# attention related params
|
||||||
@@ -1185,6 +1199,7 @@ class Unet(nn.Module):
|
|||||||
image_embed,
|
image_embed,
|
||||||
lowres_cond_img = None,
|
lowres_cond_img = None,
|
||||||
text_encodings = None,
|
text_encodings = None,
|
||||||
|
text_mask = None,
|
||||||
image_cond_drop_prob = 0.,
|
image_cond_drop_prob = 0.,
|
||||||
text_cond_drop_prob = 0.,
|
text_cond_drop_prob = 0.,
|
||||||
blur_sigma = None,
|
blur_sigma = None,
|
||||||
@@ -1230,10 +1245,25 @@ class Unet(nn.Module):
|
|||||||
|
|
||||||
if exists(text_encodings) and self.cond_on_text_encodings:
|
if exists(text_encodings) and self.cond_on_text_encodings:
|
||||||
text_tokens = self.text_to_cond(text_encodings)
|
text_tokens = self.text_to_cond(text_encodings)
|
||||||
|
text_tokens = text_tokens[:, :self.max_text_len]
|
||||||
|
|
||||||
|
text_tokens_len = text_tokens.shape[1]
|
||||||
|
remainder = self.max_text_len - text_tokens_len
|
||||||
|
|
||||||
|
if remainder > 0:
|
||||||
|
text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
|
||||||
|
|
||||||
|
if exists(text_mask):
|
||||||
|
if remainder > 0:
|
||||||
|
text_mask = F.pad(text_mask, (0, remainder), value = False)
|
||||||
|
|
||||||
|
text_mask = rearrange(text_mask, 'b n -> b n 1')
|
||||||
|
text_keep_mask = text_mask & text_keep_mask
|
||||||
|
|
||||||
text_tokens = torch.where(
|
text_tokens = torch.where(
|
||||||
text_keep_mask,
|
text_keep_mask,
|
||||||
text_tokens,
|
text_tokens,
|
||||||
self.null_text_embed[:, :text_tokens.shape[1]]
|
self.null_text_embed
|
||||||
)
|
)
|
||||||
|
|
||||||
# main conditioning tokens (c)
|
# main conditioning tokens (c)
|
||||||
@@ -1333,6 +1363,8 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
blur_sigma = 0.1, # cascading ddpm - blur sigma
|
||||||
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
blur_kernel_size = 3, # cascading ddpm - blur kernel size
|
||||||
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
condition_on_text_encodings = False, # the paper suggested that this didn't do much in the decoder, but i'm allowing the option for experimentation
|
||||||
|
clip_denoised = True,
|
||||||
|
clip_x_start = True
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
beta_schedule = beta_schedule,
|
beta_schedule = beta_schedule,
|
||||||
@@ -1409,6 +1441,11 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
self.image_cond_drop_prob = image_cond_drop_prob
|
self.image_cond_drop_prob = image_cond_drop_prob
|
||||||
self.text_cond_drop_prob = text_cond_drop_prob
|
self.text_cond_drop_prob = text_cond_drop_prob
|
||||||
|
|
||||||
|
# whether to clip when sampling
|
||||||
|
|
||||||
|
self.clip_denoised = clip_denoised
|
||||||
|
self.clip_x_start = clip_x_start
|
||||||
|
|
||||||
def get_unet(self, unet_number):
|
def get_unet(self, unet_number):
|
||||||
assert 0 < unet_number <= len(self.unets)
|
assert 0 < unet_number <= len(self.unets)
|
||||||
index = unet_number - 1
|
index = unet_number - 1
|
||||||
@@ -1434,31 +1471,31 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
image_embed, _ = self.clip.embed_image(image)
|
image_embed, _ = self.clip.embed_image(image)
|
||||||
return image_embed
|
return image_embed
|
||||||
|
|
||||||
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
def p_mean_variance(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, lowres_cond_img = None, clip_denoised = True, predict_x_start = False, cond_scale = 1.):
|
||||||
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
pred = unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img)
|
||||||
|
|
||||||
if predict_x_start:
|
if predict_x_start:
|
||||||
x_recon = pred
|
x_recon = pred
|
||||||
else:
|
else:
|
||||||
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
x_recon = self.predict_start_from_noise(x, t = t, noise = pred)
|
||||||
|
|
||||||
if clip_denoised and not predict_x_start:
|
if clip_denoised:
|
||||||
x_recon.clamp_(-1., 1.)
|
x_recon.clamp_(-1., 1.)
|
||||||
|
|
||||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
||||||
return model_mean, posterior_variance, posterior_log_variance
|
return model_mean, posterior_variance, posterior_log_variance
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample(self, unet, x, t, image_embed, text_encodings = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
def p_sample(self, unet, x, t, image_embed, text_encodings = None, text_mask = None, cond_scale = 1., lowres_cond_img = None, predict_x_start = False, clip_denoised = True, repeat_noise = False):
|
||||||
b, *_, device = *x.shape, x.device
|
b, *_, device = *x.shape, x.device
|
||||||
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
model_mean, _, model_log_variance = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, clip_denoised = clip_denoised, predict_x_start = predict_x_start)
|
||||||
noise = noise_like(x.shape, device, repeat_noise)
|
noise = noise_like(x.shape, device, repeat_noise)
|
||||||
# no noise when t == 0
|
# no noise when t == 0
|
||||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, lowres_cond_img = None, text_encodings = None, cond_scale = 1):
|
def p_sample_loop(self, unet, shape, image_embed, predict_x_start = False, clip_denoised = True, lowres_cond_img = None, text_encodings = None, text_mask = None, cond_scale = 1):
|
||||||
device = self.betas.device
|
device = self.betas.device
|
||||||
|
|
||||||
b = shape[0]
|
b = shape[0]
|
||||||
@@ -1471,14 +1508,16 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
torch.full((b,), i, device = device, dtype = torch.long),
|
torch.full((b,), i, device = device, dtype = torch.long),
|
||||||
image_embed = image_embed,
|
image_embed = image_embed,
|
||||||
text_encodings = text_encodings,
|
text_encodings = text_encodings,
|
||||||
|
text_mask = text_mask,
|
||||||
cond_scale = cond_scale,
|
cond_scale = cond_scale,
|
||||||
lowres_cond_img = lowres_cond_img,
|
lowres_cond_img = lowres_cond_img,
|
||||||
predict_x_start = predict_x_start
|
predict_x_start = predict_x_start,
|
||||||
|
clip_denoised = clip_denoised
|
||||||
)
|
)
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None):
|
def p_losses(self, unet, x_start, times, *, image_embed, lowres_cond_img = None, text_encodings = None, text_mask = None, predict_x_start = False, noise = None):
|
||||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||||
|
|
||||||
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
x_noisy = self.q_sample(x_start = x_start, t = times, noise = noise)
|
||||||
@@ -1488,6 +1527,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
times,
|
times,
|
||||||
image_embed = image_embed,
|
image_embed = image_embed,
|
||||||
text_encodings = text_encodings,
|
text_encodings = text_encodings,
|
||||||
|
text_mask = text_mask,
|
||||||
lowres_cond_img = lowres_cond_img,
|
lowres_cond_img = lowres_cond_img,
|
||||||
image_cond_drop_prob = self.image_cond_drop_prob,
|
image_cond_drop_prob = self.image_cond_drop_prob,
|
||||||
text_cond_drop_prob = self.text_cond_drop_prob,
|
text_cond_drop_prob = self.text_cond_drop_prob,
|
||||||
@@ -1500,19 +1540,25 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
@eval_decorator
|
@eval_decorator
|
||||||
def sample(self, image_embed, text = None, cond_scale = 1.):
|
def sample(
|
||||||
|
self,
|
||||||
|
image_embed,
|
||||||
|
text = None,
|
||||||
|
cond_scale = 1.,
|
||||||
|
stop_at_unet_number = None
|
||||||
|
):
|
||||||
batch_size = image_embed.shape[0]
|
batch_size = image_embed.shape[0]
|
||||||
|
|
||||||
text_encodings = None
|
text_encodings = text_mask = None
|
||||||
if exists(text):
|
if exists(text):
|
||||||
_, text_encodings, _ = self.clip.embed_text(text)
|
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||||
|
|
||||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||||
|
|
||||||
img = None
|
img = None
|
||||||
|
|
||||||
for unet, vae, channel, image_size, predict_x_start in tqdm(zip(self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
for unet_number, unet, vae, channel, image_size, predict_x_start in tqdm(zip(range(1, len(self.unets) + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start)):
|
||||||
|
|
||||||
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
context = self.one_unet_in_gpu(unet = unet) if image_embed.is_cuda else null_context()
|
||||||
|
|
||||||
@@ -1523,6 +1569,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
if unet.lowres_cond:
|
if unet.lowres_cond:
|
||||||
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
lowres_cond_img = self.to_lowres_cond(img, target_image_size = image_size)
|
||||||
|
|
||||||
|
is_latent_diffusion = isinstance(vae, VQGanVAE)
|
||||||
image_size = vae.get_encoded_fmap_size(image_size)
|
image_size = vae.get_encoded_fmap_size(image_size)
|
||||||
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
shape = (batch_size, vae.encoded_dim, image_size, image_size)
|
||||||
|
|
||||||
@@ -1534,13 +1581,18 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
shape,
|
shape,
|
||||||
image_embed = image_embed,
|
image_embed = image_embed,
|
||||||
text_encodings = text_encodings,
|
text_encodings = text_encodings,
|
||||||
|
text_mask = text_mask,
|
||||||
cond_scale = cond_scale,
|
cond_scale = cond_scale,
|
||||||
predict_x_start = predict_x_start,
|
predict_x_start = predict_x_start,
|
||||||
|
clip_denoised = not is_latent_diffusion,
|
||||||
lowres_cond_img = lowres_cond_img
|
lowres_cond_img = lowres_cond_img
|
||||||
)
|
)
|
||||||
|
|
||||||
img = vae.decode(img)
|
img = vae.decode(img)
|
||||||
|
|
||||||
|
if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
|
||||||
|
break
|
||||||
|
|
||||||
return img
|
return img
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1571,9 +1623,9 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
if not exists(image_embed):
|
if not exists(image_embed):
|
||||||
image_embed, _ = self.clip.embed_image(image)
|
image_embed, _ = self.clip.embed_image(image)
|
||||||
|
|
||||||
text_encodings = None
|
text_encodings = text_mask = None
|
||||||
if exists(text) and not exists(text_encodings):
|
if exists(text) and not exists(text_encodings):
|
||||||
_, text_encodings, _ = self.clip.embed_text(text)
|
_, text_encodings, text_mask = self.clip.embed_text(text)
|
||||||
|
|
||||||
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
assert not (self.condition_on_text_encodings and not exists(text_encodings)), 'text or text encodings must be passed into decoder if specified'
|
||||||
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
assert not (not self.condition_on_text_encodings and exists(text_encodings)), 'decoder specified not to be conditioned on text, yet it is presented'
|
||||||
@@ -1588,7 +1640,7 @@ class Decoder(BaseGaussianDiffusion):
|
|||||||
if exists(lowres_cond_img):
|
if exists(lowres_cond_img):
|
||||||
lowres_cond_img = vae.encode(lowres_cond_img)
|
lowres_cond_img = vae.encode(lowres_cond_img)
|
||||||
|
|
||||||
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
return self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start)
|
||||||
|
|
||||||
# main class
|
# main class
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,43 @@
|
|||||||
import copy
|
import copy
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.cuda.amp import autocast, GradScaler
|
||||||
|
|
||||||
|
from dalle2_pytorch.dalle2_pytorch import Decoder
|
||||||
|
from dalle2_pytorch.optimizer import get_optimizer
|
||||||
|
|
||||||
|
# helper functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
return val is not None
|
||||||
|
|
||||||
|
def cast_tuple(val, length = 1):
|
||||||
|
return val if isinstance(val, tuple) else ((val,) * length)
|
||||||
|
|
||||||
|
def pick_and_pop(keys, d):
|
||||||
|
values = list(map(lambda key: d.pop(key), keys))
|
||||||
|
return dict(zip(keys, values))
|
||||||
|
|
||||||
|
def group_dict_by_key(cond, d):
|
||||||
|
return_val = [dict(),dict()]
|
||||||
|
for key in d.keys():
|
||||||
|
match = bool(cond(key))
|
||||||
|
ind = int(not match)
|
||||||
|
return_val[ind][key] = d[key]
|
||||||
|
return (*return_val,)
|
||||||
|
|
||||||
|
def string_begins_with(prefix, str):
|
||||||
|
return str.startswith(prefix)
|
||||||
|
|
||||||
|
def group_by_key_prefix(prefix, d):
|
||||||
|
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||||
|
|
||||||
|
def groupby_prefix_and_trim(prefix, d):
|
||||||
|
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||||
|
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||||
|
return kwargs_without_prefix, kwargs
|
||||||
|
|
||||||
# exponential moving average wrapper
|
# exponential moving average wrapper
|
||||||
|
|
||||||
@@ -9,16 +46,16 @@ class EMA(nn.Module):
|
|||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
beta = 0.99,
|
beta = 0.99,
|
||||||
ema_update_after_step = 1000,
|
update_after_step = 1000,
|
||||||
ema_update_every = 10,
|
update_every = 10,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.beta = beta
|
self.beta = beta
|
||||||
self.online_model = model
|
self.online_model = model
|
||||||
self.ema_model = copy.deepcopy(model)
|
self.ema_model = copy.deepcopy(model)
|
||||||
|
|
||||||
self.ema_update_after_step = ema_update_after_step # only start EMA after this step number, starting at 0
|
self.update_after_step = update_after_step # only start EMA after this step number, starting at 0
|
||||||
self.ema_update_every = ema_update_every
|
self.update_every = update_every
|
||||||
|
|
||||||
self.register_buffer('initted', torch.Tensor([False]))
|
self.register_buffer('initted', torch.Tensor([False]))
|
||||||
self.register_buffer('step', torch.tensor([0.]))
|
self.register_buffer('step', torch.tensor([0.]))
|
||||||
@@ -26,7 +63,7 @@ class EMA(nn.Module):
|
|||||||
def update(self):
|
def update(self):
|
||||||
self.step += 1
|
self.step += 1
|
||||||
|
|
||||||
if self.step <= self.ema_update_after_step or (self.step % self.ema_update_every) != 0:
|
if self.step <= self.update_after_step or (self.step % self.update_every) != 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.initted:
|
if not self.initted:
|
||||||
@@ -51,3 +88,111 @@ class EMA(nn.Module):
|
|||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
def __call__(self, *args, **kwargs):
|
||||||
return self.ema_model(*args, **kwargs)
|
return self.ema_model(*args, **kwargs)
|
||||||
|
|
||||||
|
# trainers
|
||||||
|
|
||||||
|
class DecoderTrainer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
decoder,
|
||||||
|
use_ema = True,
|
||||||
|
lr = 3e-4,
|
||||||
|
wd = 1e-2,
|
||||||
|
max_grad_norm = None,
|
||||||
|
amp = False,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
assert isinstance(decoder, Decoder)
|
||||||
|
ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
|
||||||
|
|
||||||
|
self.decoder = decoder
|
||||||
|
self.num_unets = len(self.decoder.unets)
|
||||||
|
|
||||||
|
self.use_ema = use_ema
|
||||||
|
|
||||||
|
if use_ema:
|
||||||
|
has_lazy_linear = any([type(module) == nn.LazyLinear for module in decoder.modules()])
|
||||||
|
assert not has_lazy_linear, 'you must set the text_embed_dim on your u-nets if you plan on doing automatic exponential moving average'
|
||||||
|
|
||||||
|
self.ema_unets = nn.ModuleList([])
|
||||||
|
|
||||||
|
self.amp = amp
|
||||||
|
|
||||||
|
# be able to finely customize learning rate, weight decay
|
||||||
|
# per unet
|
||||||
|
|
||||||
|
lr, wd = map(partial(cast_tuple, length = self.num_unets), (lr, wd))
|
||||||
|
|
||||||
|
for ind, (unet, unet_lr, unet_wd) in enumerate(zip(self.decoder.unets, lr, wd)):
|
||||||
|
optimizer = get_optimizer(
|
||||||
|
unet.parameters(),
|
||||||
|
lr = unet_lr,
|
||||||
|
wd = unet_wd,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
self.ema_unets.append(EMA(unet, **ema_kwargs))
|
||||||
|
|
||||||
|
scaler = GradScaler(enabled = amp)
|
||||||
|
setattr(self, f'scaler{ind}', scaler)
|
||||||
|
|
||||||
|
# gradient clipping if needed
|
||||||
|
|
||||||
|
self.max_grad_norm = max_grad_norm
|
||||||
|
|
||||||
|
@property
|
||||||
|
def unets(self):
|
||||||
|
return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
|
||||||
|
|
||||||
|
def scale(self, loss, *, unet_number):
|
||||||
|
assert 1 <= unet_number <= self.num_unets
|
||||||
|
index = unet_number - 1
|
||||||
|
scaler = getattr(self, f'scaler{index}')
|
||||||
|
return scaler.scale(loss)
|
||||||
|
|
||||||
|
def update(self, unet_number):
|
||||||
|
assert 1 <= unet_number <= self.num_unets
|
||||||
|
index = unet_number - 1
|
||||||
|
unet = self.decoder.unets[index]
|
||||||
|
|
||||||
|
if exists(self.max_grad_norm):
|
||||||
|
nn.utils.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
optimizer = getattr(self, f'optim{index}')
|
||||||
|
scaler = getattr(self, f'scaler{index}')
|
||||||
|
|
||||||
|
scaler.step(optimizer)
|
||||||
|
scaler.update()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
ema_unet = self.ema_unets[index]
|
||||||
|
ema_unet.update()
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def sample(self, *args, **kwargs):
|
||||||
|
if self.use_ema:
|
||||||
|
trainable_unets = self.decoder.unets
|
||||||
|
self.decoder.unets = self.unets # swap in exponential moving averaged unets for sampling
|
||||||
|
|
||||||
|
output = self.decoder.sample(*args, **kwargs)
|
||||||
|
|
||||||
|
if self.use_ema:
|
||||||
|
self.decoder.unets = trainable_unets # restore original training unets
|
||||||
|
return output
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
*,
|
||||||
|
unet_number,
|
||||||
|
divisor = 1,
|
||||||
|
**kwargs
|
||||||
|
):
|
||||||
|
with autocast(enabled = self.amp):
|
||||||
|
loss = self.decoder(x, unet_number = unet_number, **kwargs)
|
||||||
|
return self.scale(loss / divisor, unet_number = unet_number)
|
||||||
|
|||||||
Reference in New Issue
Block a user