@@ -389,6 +389,8 @@ class OpenClipAdapter(BaseClipAdapter):
         self.eos_id = 49407

         text_attention_final = self.find_layer('ln_final')
+        self._dim_latent = text_attention_final.weight.shape[0]
+
         self.handle = text_attention_final.register_forward_hook(self._hook)
         self.clip_normalize = preprocess.transforms[-1]
         self.cleared = False
@@ -408,7 +410,7 @@ class OpenClipAdapter(BaseClipAdapter):

     @property
     def dim_latent(self):
-        return 512
+        return self._dim_latent

     @property
     def image_size(self):
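A note on the two hunks above: the hardcoded `return 512` only held for the smallest OpenCLIP text towers; larger variants (e.g. ViT-L/14, whose text latents are 768-dimensional) would silently break downstream projections. Reading the width off `ln_final.weight` at init makes the adapter variant-agnostic. A minimal sketch of the idea, using a stand-in module rather than a real OpenCLIP model:

```python
import torch.nn as nn

# stand-in for an OpenCLIP text tower; only ln_final matters for this sketch
class FakeClip(nn.Module):
    def __init__(self, width):
        super().__init__()
        self.ln_final = nn.LayerNorm(width)

clip = FakeClip(width = 768)

# the final LayerNorm carries one weight per latent dimension,
# so its length recovers the text latent width for any variant
dim_latent = clip.ln_final.weight.shape[0]
assert dim_latent == 768
```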
@@ -617,7 +619,7 @@ class NoiseScheduler(nn.Module):
         posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
         return posterior_mean, posterior_variance, posterior_log_variance_clipped

-    def q_sample(self, x_start, t, noise=None):
+    def q_sample(self, x_start, t, noise = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         return (
@@ -625,6 +627,12 @@ class NoiseScheduler(nn.Module):
             extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
         )

+    def calculate_v(self, x_start, t, noise = None):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+        )
+
     def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
         shape = x_from.shape
         noise = default(noise, lambda: torch.randn_like(x_from))
@@ -636,6 +644,12 @@ class NoiseScheduler(nn.Module):

         return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha

+    def predict_start_from_v(self, x_t, t, v):
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+        )
+
     def predict_start_from_noise(self, x_t, t, noise):
         return (
             extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
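The two new `NoiseScheduler` methods above implement the v-parameterization of Salimans & Ho (Progressive Distillation, 2022). Writing `a = sqrt(alphas_cumprod[t])` and `s = sqrt(1 - alphas_cumprod[t])`, the scheduler already forms `x_t = a * x0 + s * noise` in `q_sample`; `calculate_v` builds the target `v = a * noise - s * x0`, and `predict_start_from_v` inverts it via `x0 = a * x_t - s * v`, which follows from `a**2 + s**2 = 1`. A standalone numerical check of that algebra, using plain tensors rather than the library's `extract`-based buffers:

```python
import torch

a, s = torch.tensor(0.8), torch.tensor(0.6)  # a**2 + s**2 == 1
x0, noise = torch.randn(4), torch.randn(4)

x_t = a * x0 + s * noise       # q_sample
v = a * noise - s * x0         # calculate_v
x0_recovered = a * x_t - s * v # predict_start_from_v

# a*(a*x0 + s*eps) - s*(a*eps - s*x0) = (a**2 + s**2) * x0 = x0
assert torch.allclose(x0, x0_recovered, atol = 1e-6)
```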
@@ -1144,6 +1158,7 @@ class DiffusionPrior(nn.Module):
         image_cond_drop_prob = None,
         loss_type = "l2",
         predict_x_start = True,
+        predict_v = False,
         beta_schedule = "cosine",
         condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
         sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
@@ -1195,6 +1210,7 @@ class DiffusionPrior(nn.Module):
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.

         self.predict_x_start = predict_x_start
+        self.predict_v = predict_v # takes precedence over predict_x_start

         # @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
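With the constructor changes above, v-prediction for the prior is opt-in via a single flag, and when set it takes precedence over `predict_x_start`. A usage sketch in the style of the repository README's prior setup (the network hyperparameters here are illustrative, not prescribed by this patch):

```python
from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter

clip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 1000,
    cond_drop_prob = 0.2,
    predict_v = True  # new flag; overrides predict_x_start during training and sampling
)
```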
@@ -1224,7 +1240,9 @@ class DiffusionPrior(nn.Module):
         pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)

-        if self.predict_x_start:
+        if self.predict_v:
+            x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif self.predict_x_start:
             x_start = pred
         else:
             x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -1297,7 +1315,9 @@ class DiffusionPrior(nn.Module):
             # derive x0

-            if self.predict_x_start:
+            if self.predict_v:
+                x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
+            elif self.predict_x_start:
                 x_start = pred
             else:
                 x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
@@ -1312,7 +1332,7 @@ class DiffusionPrior(nn.Module):
             # predict noise

-            if self.predict_x_start:
+            if self.predict_x_start or self.predict_v:
                 pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
             else:
                 pred_noise = pred
@@ -1370,7 +1390,12 @@ class DiffusionPrior(nn.Module):
         if self.predict_x_start and self.training_clamp_l2norm:
             pred = self.l2norm_clamp_embed(pred)

-        target = noise if not self.predict_x_start else image_embed
+        if self.predict_v:
+            target = self.noise_scheduler.calculate_v(image_embed, times, noise)
+        elif self.predict_x_start:
+            target = image_embed
+        else:
+            target = noise

         loss = self.noise_scheduler.loss_fn(pred, target)
         return loss
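The training hunk above replaces the two-way ternary with an explicit three-way target selection. Unlike the epsilon target, whose contribution to the reconstructed x0 is amplified by `1 / sqrt(alphas_cumprod[t])` as t grows, v is a rotation of (x0, eps) and stays unit-scale at every noise level, which is the usual motivation for this objective. The three targets side by side, in the same a/s notation as the earlier sketch (standalone stand-ins, not the library API):

```python
import torch

a, s = torch.tensor(0.8), torch.tensor(0.6)  # sqrt(alphas_cumprod[t]), sqrt(1 - alphas_cumprod[t])
image_embed, noise = torch.randn(4), torch.randn(4)

target_eps = noise                        # default objective
target_x0 = image_embed                   # predict_x_start = True
target_v = a * noise - s * image_embed    # predict_v = True (takes precedence)
```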
@@ -2446,6 +2471,7 @@ class Decoder(nn.Module):
         loss_type = 'l2',
         beta_schedule = None,
         predict_x_start = False,
+        predict_v = False,
         predict_x_start_for_latent_diffusion = False,
         image_sizes = None, # for cascading ddpm, image size at each stage
         random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
@@ -2618,6 +2644,10 @@ class Decoder(nn.Module):
         self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))

+        # predict v
+
+        self.predict_v = cast_tuple(predict_v, len(unets))
+
         # input image range

         self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
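`cast_tuple` lets `predict_v` be given either as one bool for the whole cascade or as a per-unet tuple, mirroring how `predict_x_start` and `learned_variance` are already handled. A minimal sketch of that broadcasting behavior (a simplified re-implementation for illustration, not the repository's exact helper):

```python
def cast_tuple(val, length = 1):
    # broadcast a scalar to a tuple of the given length; pass tuples through untouched
    return val if isinstance(val, tuple) else ((val,) * length)

assert cast_tuple(False, 3) == (False, False, False)    # one flag for all unets
assert cast_tuple((True, False), 2) == (True, False)    # per-unet control
```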
@@ -2729,14 +2759,16 @@ class Decoder(nn.Module):
         x = x.clamp(-s, s) / s
         return x

-    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
+    def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'

         model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))

         pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)

-        if predict_x_start:
+        if predict_v:
+            x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
+        elif predict_x_start:
             x_start = pred
         else:
             x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
@@ -2763,9 +2795,9 @@ class Decoder(nn.Module):
         return model_mean, posterior_variance, posterior_log_variance, x_start

     @torch.no_grad()
-    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
+    def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
@@ -2780,6 +2812,7 @@ class Decoder(nn.Module):
         image_embed,
         noise_scheduler,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2838,6 +2871,7 @@ class Decoder(nn.Module):
                 lowres_cond_img = lowres_cond_img,
                 lowres_noise_level = lowres_noise_level,
                 predict_x_start = predict_x_start,
+                predict_v = predict_v,
                 noise_scheduler = noise_scheduler,
                 learned_variance = learned_variance,
                 clip_denoised = clip_denoised
@@ -2863,6 +2897,7 @@ class Decoder(nn.Module):
         timesteps,
         eta = 1.,
         predict_x_start = False,
+        predict_v = False,
         learned_variance = False,
         clip_denoised = True,
         lowres_cond_img = None,
@@ -2924,7 +2959,9 @@ class Decoder(nn.Module):
             # predict x0

-            if predict_x_start:
+            if predict_v:
+                x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
+            elif predict_x_start:
                 x_start = pred
             else:
                 x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
@@ -2936,8 +2973,8 @@ class Decoder(nn.Module):
             # predict noise

-            if predict_x_start:
-                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
+            if predict_x_start or predict_v:
+                pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
             else:
                 pred_noise = pred
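Two things change in the DDIM noise step above: the branch now also fires for v-prediction, and the noise is re-derived from `x_start` rather than from the raw `pred`. The latter matters because `x_start` may have been clamped under `clip_denoised` (and under v-prediction `pred` is not an x0 at all). The inversion itself is just `q_sample` solved for the noise; a standalone check in the same a/s notation as before (the library's `predict_noise_from_start` uses the equivalent `sqrt_recip_alphas_cumprod` form):

```python
import torch

a, s = torch.tensor(0.8), torch.tensor(0.6)  # a**2 + s**2 == 1
x0, noise = torch.randn(4), torch.randn(4)

x_t = a * x0 + s * noise               # q_sample
noise_recovered = (x_t - a * x0) / s   # predict_noise_from_start, simplified

assert torch.allclose(noise, noise_recovered, atol = 1e-6)
```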
@@ -2973,7 +3010,7 @@ class Decoder(nn.Module):
         return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)

-    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
+    def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
         noise = default(noise, lambda: torch.randn_like(x_start))

         # normalize to [-1, 1]
@@ -3018,7 +3055,12 @@ class Decoder(nn.Module):
         pred, _ = self.parse_unet_output(learned_variance, unet_output)

-        target = noise if not predict_x_start else x_start
+        if predict_v:
+            target = noise_scheduler.calculate_v(x_start, times, noise)
+        elif predict_x_start:
+            target = x_start
+        else:
+            target = noise

         loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
         loss = reduce(loss, 'b ... -> b (...)', 'mean')
@@ -3104,7 +3146,7 @@ class Decoder(nn.Module):
         num_unets = self.num_unets
         cond_scale = cast_tuple(cond_scale, num_unets)

-        for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
+        for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
             if unet_number < start_at_unet_number:
                 continue # It's the easiest way to do it
@@ -3140,6 +3182,7 @@ class Decoder(nn.Module):
                 text_encodings = text_encodings,
                 cond_scale = unet_cond_scale,
                 predict_x_start = predict_x_start,
+                predict_v = predict_v,
                 learned_variance = learned_variance,
                 clip_denoised = not is_latent_diffusion,
                 lowres_cond_img = lowres_cond_img,
@@ -3179,6 +3222,7 @@ class Decoder(nn.Module):
         lowres_conditioner = self.lowres_conds[unet_index]
         target_image_size = self.image_sizes[unet_index]
         predict_x_start = self.predict_x_start[unet_index]
+        predict_v = self.predict_v[unet_index]
         random_crop_size = self.random_crop_sizes[unet_index]
         learned_variance = self.learned_variance[unet_index]
         b, c, h, w, device, = *image.shape, image.device
@@ -3217,7 +3261,7 @@ class Decoder(nn.Module):
             image = vae.encode(image)
             lowres_cond_img = maybe(vae.encode)(lowres_cond_img)

-        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
+        losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)

         if not return_lowres_cond_image:
             return losses
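Taken together, the `Decoder` changes thread `predict_v` from the constructor through `p_mean_variance`, both samplers, and `p_losses`, one flag per unet in the cascade. A usage sketch in the style of the README (the unet hyperparameters and image sizes here are illustrative only):

```python
from dalle2_pytorch import Unet, Decoder, OpenAIClipAdapter

clip = OpenAIClipAdapter()

unet1 = Unet(dim = 128, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8))
unet2 = Unet(dim = 16, image_embed_dim = 512, cond_dim = 128, channels = 3, dim_mults = (1, 2, 4, 8, 16))

decoder = Decoder(
    unet = (unet1, unet2),
    image_sizes = (128, 256),
    clip = clip,
    timesteps = 1000,
    predict_v = (True, False)  # new flag; a per-unet tuple, or a single bool for the whole cascade
)
```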