Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-18 17:34:18 +01:00)
bring in prediction of the v objective, combining findings from the progressive distillation paper and imagen-video, towards the eventual extension of dalle2 to make-a-video
 README.md | 10 ++++++++++
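For orientation (commentary, not part of the commit): the v objective from the progressive distillation paper cited below has the network predict

```latex
v_t = \sqrt{\bar{\alpha}_t}\,\epsilon - \sqrt{1 - \bar{\alpha}_t}\,x_0
```

instead of the noise epsilon or the clean signal x_0 directly. The diff adds this as a third parameterization: `calculate_v` and `predict_start_from_v` on `NoiseScheduler`, plus a `predict_v` flag threaded through `DiffusionPrior` and `Decoder` that takes precedence over `predict_x_start`.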
--- a/README.md
+++ b/README.md
@@ -1298,4 +1298,14 @@ For detailed information on training the diffusion prior, please refer to the [d
 }
 ```
 
+```bibtex
+@article{Salimans2022ProgressiveDF,
+    title   = {Progressive Distillation for Fast Sampling of Diffusion Models},
+    author  = {Tim Salimans and Jonathan Ho},
+    journal = {ArXiv},
+    year    = {2022},
+    volume  = {abs/2202.00512}
+}
+```
+
 *Creating noise from data is easy; creating data from noise is generative modeling.* - <a href="https://arxiv.org/abs/2011.13456">Yang Song's paper</a>
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -619,7 +619,7 @@ class NoiseScheduler(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped
 
-    def q_sample(self, x_start, t, noise=None):
+    def q_sample(self, x_start, t, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         return (
@@ -627,6 +627,12 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )
 
+    def calculate_v(self, x_start, t, noise = None):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+        )
+
     def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
         shape = x_from.shape
         noise = default(noise, lambda: torch.randn_like(x_from))
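For reference (not part of the diff): `calculate_v` instantiates the v objective from Salimans & Ho. A minimal self-contained sketch of the round trip this enables, assuming only the variance-preserving convention that the scheduler's `sqrt_alphas_cumprod` / `sqrt_one_minus_alphas_cumprod` buffers satisfy:

```python
# standalone sketch, not the repo's code; alpha / sigma stand in for
# sqrt(alphas_cumprod)[t] / sqrt(1 - alphas_cumprod)[t], so alpha**2 + sigma**2 = 1

import torch

alpha = torch.rand(8).clamp(min = 1e-3)   # hypothetical per-sample sqrt(alphas_cumprod)
sigma = (1 - alpha ** 2).sqrt()

x_start = torch.randn(8)
noise   = torch.randn(8)

x_t = alpha * x_start + sigma * noise    # q_sample
v   = alpha * noise - sigma * x_start    # calculate_v

# predict_start_from_v inverts the two equations above
assert torch.allclose(alpha * x_t - sigma * v, x_start, atol = 1e-5)
# the noise is recoverable symmetrically: eps = alpha * v + sigma * x_t
assert torch.allclose(alpha * v + sigma * x_t, noise, atol = 1e-5)
```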
@@ -638,6 +644,12 @@ class NoiseScheduler(nn.Module):
 
         return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
 
+    def predict_start_from_v(self, x_t, t, v):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
     def predict_start_from_noise(self, x_t, t, noise):
         return (
             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
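The algebra behind `predict_start_from_v` (commentary, not part of the diff), writing alpha_t = sqrt(alphas_cumprod) and sigma_t = sqrt(1 - alphas_cumprod) so that alpha_t^2 + sigma_t^2 = 1:

```latex
\begin{aligned}
x_t &= \alpha_t\, x_0 + \sigma_t\, \epsilon && \text{(q\_sample)} \\
v   &= \alpha_t\, \epsilon - \sigma_t\, x_0 && \text{(calculate\_v)} \\
\alpha_t\, x_t - \sigma_t\, v &= (\alpha_t^2 + \sigma_t^2)\, x_0 = x_0 && \text{(predict\_start\_from\_v)}
\end{aligned}
```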
@@ -1146,6 +1158,7 @@ class DiffusionPrior(nn.Module):
         image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
+        predict_v = False,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1197,6 +1210,7 @@ class DiffusionPrior(nn.Module):
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
 
         self.predict_x_start = predict_x_start
+        self.predict_v = predict_v # takes precedence over predict_x_start
 
         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
 
@@ -1226,7 +1240,9 @@ class DiffusionPrior(nn.Module):
 
         pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
 
-        if self.predict_x_start:
+        if self.predict_v:
+            x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif self.predict_x_start:
             x_start = pred
         else:
             x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1299,7 +1315,9 @@ class DiffusionPrior(nn.Module):
 
             # derive x0
 
-            if self.predict_x_start:
+            if self.predict_v:
+                x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
+            elif self.predict_x_start:
                 x_start = pred
             else:
                 x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
@@ -1314,7 +1332,7 @@ class DiffusionPrior(nn.Module):
 
             # predict noise
 
-            if self.predict_x_start:
+            if self.predict_x_start or self.predict_v:
                 pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
             else:
                 pred_noise = pred
@@ -1372,7 +1390,12 @@ class DiffusionPrior(nn.Module):
         if self.predict_x_start and self.training_clamp_l2norm:
             pred = self.l2norm_clamp_embed(pred)
 
-        target = noise if not self.predict_x_start else image_embed
+        if self.predict_v:
+            target = self.noise_scheduler.calculate_v(image_embed, times, noise)
+        elif self.predict_x_start:
+            target = image_embed
+        else:
+            target = noise
 
         loss = self.noise_scheduler.loss_fn(pred, target)
         return loss
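A condensed view of the training-target precedence introduced above, as a standalone sketch (not the repo's code; `scheduler`, `image_embed`, `times`, `noise` are stand-ins):

```python
def prior_target(scheduler, image_embed, times, noise, *, predict_v, predict_x_start):
    # predict_v takes precedence over predict_x_start, matching the __init__ comment
    if predict_v:
        return scheduler.calculate_v(image_embed, times, noise)
    if predict_x_start:
        return image_embed
    return noise
```

One reason the v target is attractive, per the progressive distillation paper: it stays well-scaled across the whole schedule, approaching epsilon at low noise levels and -x_0 at high noise levels, where pure epsilon prediction becomes ill-conditioned.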
@@ -2448,6 +2471,7 @@ class Decoder(nn.Module):
         loss_type = 'l2',
         beta_schedule = None,
         predict_x_start = False,
+        predict_v = False,
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2620,6 +2644,10 @@ class Decoder(nn.Module):
 
         self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
 
+        # predict v
+
+        self.predict_v = cast_tuple(predict_v, len(unets))
+
         # input image range
 
         self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
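As with `predict_x_start`, the flag is broadcast to one value per unet in the cascade. A sketch of the `cast_tuple` behavior being relied on (a hypothetical reimplementation; the repo's helper may differ in details):

```python
def cast_tuple(val, length = 1):
    # pass tuples through untouched, broadcast scalars to the requested length
    return val if isinstance(val, tuple) else ((val,) * length)

assert cast_tuple(False, 3) == (False, False, False)                 # one flag for all unets
assert cast_tuple((True, False, False), 3) == (True, False, False)   # per-unet override
```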
@@ -2731,14 +2759,16 @@ class Decoder(nn.Module):
             x = x.clamp(-s, s) / s
         return x
 
-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
 
         model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
 
         pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
 
-        if predict_x_start:
+        if predict_v:
+            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif predict_x_start:
             x_start = pred
         else:
             x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2765,9 +2795,9 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance, x_start
 
     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2782,6 +2812,7 @@ class Decoder(nn.Module):
         image_embed,
         noise_scheduler,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2840,6 +2871,7 @@ class Decoder(nn.Module):
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_level = lowres_noise_level,
                 predict_x_start = predict_x_start,
+                predict_v = predict_v,
                 noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
                 clip_denoised = clip_denoised
@@ -2865,6 +2897,7 @@ class Decoder(nn.Module):
         timesteps,
         eta = 1.,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2926,7 +2959,9 @@ class Decoder(nn.Module):
 
             # predict x0
 
-            if predict_x_start:
+            if predict_v:
+                x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
+            elif predict_x_start:
                 x_start = pred
             else:
                 x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
@@ -2938,8 +2973,8 @@ class Decoder(nn.Module):
 
             # predict noise
 
-            if predict_x_start:
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+            if predict_x_start or predict_v:
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
             else:
                 pred_noise = pred
 
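Two things change in this hunk (commentary, not part of the diff): the branch now also covers `predict_v`, and `pred_noise` is derived from `x_start` rather than the raw `pred`. The second part matters because under v-prediction the network output is v, not x_0, and `x_start` may also have been clipped in between. `predict_noise_from_start` inverts the forward process via the standard identity:

```latex
\hat{\epsilon} = \frac{\sqrt{1/\bar{\alpha}_t}\, x_t - \hat{x}_0}{\sqrt{1/\bar{\alpha}_t - 1}}
               = \frac{x_t - \sqrt{\bar{\alpha}_t}\, \hat{x}_0}{\sqrt{1 - \bar{\alpha}_t}}
```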
@@ -2975,7 +3010,7 @@ class Decoder(nn.Module):
 
         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
 
-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
         noise = default(noise, lambda: torch.randn_like(x_start))
 
         # normalize to [-1, 1]
@@ -3020,7 +3055,12 @@ class Decoder(nn.Module):
 
         pred, _ = self.parse_unet_output(learned_variance, unet_output)
 
-        target = noise if not predict_x_start else x_start
+        if predict_v:
+            target = noise_scheduler.calculate_v(x_start, times, noise)
+        elif predict_x_start:
+            target = x_start
+        else:
+            target = noise
 
         loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
         loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -3106,7 +3146,7 @@ class Decoder(nn.Module):
         num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)
 
-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
             if unet_number < start_at_unet_number:
                 continue # It's the easiest way to do it
 
@@ -3142,6 +3182,7 @@ class Decoder(nn.Module):
                     text_encodings = text_encodings,
                     cond_scale = unet_cond_scale,
                     predict_x_start = predict_x_start,
+                    predict_v = predict_v,
                     learned_variance = learned_variance,
                     clip_denoised = not is_latent_diffusion,
                     lowres_cond_img = lowres_cond_img,
@@ -3181,6 +3222,7 @@ class Decoder(nn.Module):
         lowres_conditioner = self.lowres_conds[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
+        predict_v = self.predict_v[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
         learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device
@@ -3219,7 +3261,7 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
 
-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
 
         if not return_lowres_cond_image:
             return losses
--- a/dalle2_pytorch/version.py
+++ b/dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.10.9'
+__version__ = '1.11.1'