|
|
|
|
@@ -100,6 +100,9 @@ def eval_decorator(fn):
|
|
|
|
|
return out
|
|
|
|
|
return inner
|
|
|
|
|
|
|
|
|
|
def is_float_dtype(dtype):
|
|
|
|
|
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
|
|
|
|
|
|
|
|
|
|
def is_list_str(x):
|
|
|
|
|
if not isinstance(x, (list, tuple)):
|
|
|
|
|
return False
|
|
|
|
|
@@ -386,6 +389,8 @@ class OpenClipAdapter(BaseClipAdapter):
|
|
|
|
|
self.eos_id = 49407
|
|
|
|
|
|
|
|
|
|
text_attention_final = self.find_layer('ln_final')
|
|
|
|
|
self._dim_latent = text_attention_final.weight.shape[0]
|
|
|
|
|
|
|
|
|
|
self.handle = text_attention_final.register_forward_hook(self._hook)
|
|
|
|
|
self.clip_normalize = preprocess.transforms[-1]
|
|
|
|
|
self.cleared = False
|
|
|
|
|
@@ -405,7 +410,7 @@ class OpenClipAdapter(BaseClipAdapter):
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def dim_latent(self):
|
|
|
|
|
return 512
|
|
|
|
|
return self._dim_latent
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def image_size(self):
|
|
|
|
|
@@ -614,7 +619,7 @@ class NoiseScheduler(nn.Module):
|
|
|
|
|
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
|
|
|
|
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
|
|
|
|
|
|
|
|
|
def q_sample(self, x_start, t, noise=None):
|
|
|
|
|
def q_sample(self, x_start, t, noise = None):
|
|
|
|
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
|
|
|
|
|
|
|
|
|
return (
|
|
|
|
|
@@ -622,6 +627,12 @@ class NoiseScheduler(nn.Module):
|
|
|
|
|
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def calculate_v(self, x_start, t, noise = None):
|
|
|
|
|
return (
|
|
|
|
|
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
|
|
|
|
|
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
|
|
|
|
|
shape = x_from.shape
|
|
|
|
|
noise = default(noise, lambda: torch.randn_like(x_from))
|
|
|
|
|
@@ -633,6 +644,12 @@ class NoiseScheduler(nn.Module):
|
|
|
|
|
|
|
|
|
|
return x_from * (alpha_next / alpha) + noise * (sigma_next * alpha - sigma * alpha_next) / alpha
|
|
|
|
|
|
|
|
|
|
def predict_start_from_v(self, x_t, t, v):
|
|
|
|
|
return (
|
|
|
|
|
extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
|
|
|
|
|
extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def predict_start_from_noise(self, x_t, t, noise):
|
|
|
|
|
return (
|
|
|
|
|
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
|
|
|
|
@@ -968,6 +985,8 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
Rearrange('b (n d) -> b n d', n = num_text_embeds)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.continuous_embedded_time = not exists(num_timesteps)
|
|
|
|
|
|
|
|
|
|
self.to_time_embeds = nn.Sequential(
|
|
|
|
|
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
|
|
|
|
|
Rearrange('b (n d) -> b n d', n = num_time_embeds)
|
|
|
|
|
@@ -1095,6 +1114,9 @@ class DiffusionPriorNetwork(nn.Module):
|
|
|
|
|
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
|
|
|
|
# but let's just do it right
|
|
|
|
|
|
|
|
|
|
if self.continuous_embedded_time:
|
|
|
|
|
diffusion_timesteps = diffusion_timesteps.type(dtype)
|
|
|
|
|
|
|
|
|
|
time_embed = self.to_time_embeds(diffusion_timesteps)
|
|
|
|
|
|
|
|
|
|
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
|
|
|
|
@@ -1136,6 +1158,7 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
image_cond_drop_prob = None,
|
|
|
|
|
loss_type = "l2",
|
|
|
|
|
predict_x_start = True,
|
|
|
|
|
predict_v = False,
|
|
|
|
|
beta_schedule = "cosine",
|
|
|
|
|
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
|
|
|
|
|
sampling_clamp_l2norm = False, # whether to l2norm clamp the image embed at each denoising iteration (analogous to -1 to 1 clipping for usual DDPMs)
|
|
|
|
|
@@ -1187,6 +1210,7 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
# in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
|
|
|
|
|
|
|
|
|
|
self.predict_x_start = predict_x_start
|
|
|
|
|
self.predict_v = predict_v # takes precedence over predict_x_start
|
|
|
|
|
|
|
|
|
|
# @crowsonkb 's suggestion - https://github.com/lucidrains/DALLE2-pytorch/issues/60#issue-1226116132
|
|
|
|
|
|
|
|
|
|
@@ -1216,7 +1240,9 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
|
|
|
|
|
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
|
|
|
|
|
|
|
|
|
|
if self.predict_x_start:
|
|
|
|
|
if self.predict_v:
|
|
|
|
|
x_start = self.noise_scheduler.predict_start_from_v(x, t = t, v = pred)
|
|
|
|
|
elif self.predict_x_start:
|
|
|
|
|
x_start = pred
|
|
|
|
|
else:
|
|
|
|
|
x_start = self.noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
|
|
|
|
@@ -1289,10 +1315,12 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
|
|
|
|
|
# derive x0
|
|
|
|
|
|
|
|
|
|
if self.predict_x_start:
|
|
|
|
|
if self.predict_v:
|
|
|
|
|
x_start = self.noise_scheduler.predict_start_from_v(image_embed, t = time_cond, v = pred)
|
|
|
|
|
elif self.predict_x_start:
|
|
|
|
|
x_start = pred
|
|
|
|
|
else:
|
|
|
|
|
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred_noise)
|
|
|
|
|
x_start = self.noise_scheduler.predict_start_from_noise(image_embed, t = time_cond, noise = pred)
|
|
|
|
|
|
|
|
|
|
# clip x0 before maybe predicting noise
|
|
|
|
|
|
|
|
|
|
@@ -1304,7 +1332,7 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
|
|
|
|
|
# predict noise
|
|
|
|
|
|
|
|
|
|
if self.predict_x_start:
|
|
|
|
|
if self.predict_x_start or self.predict_v:
|
|
|
|
|
pred_noise = self.noise_scheduler.predict_noise_from_start(image_embed, t = time_cond, x0 = x_start)
|
|
|
|
|
else:
|
|
|
|
|
pred_noise = pred
|
|
|
|
|
@@ -1362,7 +1390,12 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
if self.predict_x_start and self.training_clamp_l2norm:
|
|
|
|
|
pred = self.l2norm_clamp_embed(pred)
|
|
|
|
|
|
|
|
|
|
target = noise if not self.predict_x_start else image_embed
|
|
|
|
|
if self.predict_v:
|
|
|
|
|
target = self.noise_scheduler.calculate_v(image_embed, times, noise)
|
|
|
|
|
elif self.predict_x_start:
|
|
|
|
|
target = image_embed
|
|
|
|
|
else:
|
|
|
|
|
target = noise
|
|
|
|
|
|
|
|
|
|
loss = self.noise_scheduler.loss_fn(pred, target)
|
|
|
|
|
return loss
|
|
|
|
|
@@ -1432,7 +1465,7 @@ class DiffusionPrior(nn.Module):
|
|
|
|
|
**kwargs
|
|
|
|
|
):
|
|
|
|
|
assert exists(text) ^ exists(text_embed), 'either text or text embedding must be supplied'
|
|
|
|
|
assert exists(image) ^ exists(image_embed), 'either text or text embedding must be supplied'
|
|
|
|
|
assert exists(image) ^ exists(image_embed), 'either image or image embedding must be supplied'
|
|
|
|
|
assert not (self.condition_on_text_encodings and (not exists(text_encodings) and not exists(text))), 'text encodings must be present if you specified you wish to condition on it on initialization'
|
|
|
|
|
|
|
|
|
|
if exists(image):
|
|
|
|
|
@@ -1538,6 +1571,8 @@ class SinusoidalPosEmb(nn.Module):
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
dtype, device = x.dtype, x.device
|
|
|
|
|
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'
|
|
|
|
|
|
|
|
|
|
half_dim = self.dim // 2
|
|
|
|
|
emb = math.log(10000) / (half_dim - 1)
|
|
|
|
|
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
|
|
|
|
|
@@ -2436,6 +2471,7 @@ class Decoder(nn.Module):
|
|
|
|
|
loss_type = 'l2',
|
|
|
|
|
beta_schedule = None,
|
|
|
|
|
predict_x_start = False,
|
|
|
|
|
predict_v = False,
|
|
|
|
|
predict_x_start_for_latent_diffusion = False,
|
|
|
|
|
image_sizes = None, # for cascading ddpm, image size at each stage
|
|
|
|
|
random_crop_sizes = None, # whether to random crop the image at that stage in the cascade (super resoluting convolutions at the end may be able to generalize on smaller crops)
|
|
|
|
|
@@ -2608,6 +2644,10 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
self.predict_x_start = cast_tuple(predict_x_start, len(unets)) if not predict_x_start_for_latent_diffusion else tuple(map(lambda t: isinstance(t, VQGanVAE), self.vaes))
|
|
|
|
|
|
|
|
|
|
# predict v
|
|
|
|
|
|
|
|
|
|
self.predict_v = cast_tuple(predict_v, len(unets))
|
|
|
|
|
|
|
|
|
|
# input image range
|
|
|
|
|
|
|
|
|
|
self.input_image_range = (-1. if not auto_normalize_img else 0., 1.)
|
|
|
|
|
@@ -2719,14 +2759,16 @@ class Decoder(nn.Module):
|
|
|
|
|
x = x.clamp(-s, s) / s
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
|
|
|
|
def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, predict_v = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
|
|
|
|
|
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
|
|
|
|
|
|
|
|
|
model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
|
|
|
|
|
|
|
|
|
|
pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
|
|
|
|
|
|
|
|
|
|
if predict_x_start:
|
|
|
|
|
if predict_v:
|
|
|
|
|
x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
|
|
|
|
|
elif predict_x_start:
|
|
|
|
|
x_start = pred
|
|
|
|
|
else:
|
|
|
|
|
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
|
|
|
|
|
@@ -2753,9 +2795,9 @@ class Decoder(nn.Module):
|
|
|
|
|
return model_mean, posterior_variance, posterior_log_variance, x_start
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
|
|
|
|
|
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
|
|
|
|
|
def p_sample(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, cond_scale = 1., lowres_cond_img = None, self_cond = None, predict_x_start = False, predict_v = False, learned_variance = False, clip_denoised = True, lowres_noise_level = None):
|
|
|
|
|
b, *_, device = *x.shape, x.device
|
|
|
|
|
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
model_mean, _, model_log_variance, x_start = self.p_mean_variance(unet, x = x, t = t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, clip_denoised = clip_denoised, predict_x_start = predict_x_start, predict_v = predict_v, noise_scheduler = noise_scheduler, learned_variance = learned_variance, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
noise = torch.randn_like(x)
|
|
|
|
|
# no noise when t == 0
|
|
|
|
|
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
|
|
|
|
@@ -2770,6 +2812,7 @@ class Decoder(nn.Module):
|
|
|
|
|
image_embed,
|
|
|
|
|
noise_scheduler,
|
|
|
|
|
predict_x_start = False,
|
|
|
|
|
predict_v = False,
|
|
|
|
|
learned_variance = False,
|
|
|
|
|
clip_denoised = True,
|
|
|
|
|
lowres_cond_img = None,
|
|
|
|
|
@@ -2828,6 +2871,7 @@ class Decoder(nn.Module):
|
|
|
|
|
lowres_cond_img = lowres_cond_img,
|
|
|
|
|
lowres_noise_level = lowres_noise_level,
|
|
|
|
|
predict_x_start = predict_x_start,
|
|
|
|
|
predict_v = predict_v,
|
|
|
|
|
noise_scheduler = noise_scheduler,
|
|
|
|
|
learned_variance = learned_variance,
|
|
|
|
|
clip_denoised = clip_denoised
|
|
|
|
|
@@ -2853,6 +2897,7 @@ class Decoder(nn.Module):
|
|
|
|
|
timesteps,
|
|
|
|
|
eta = 1.,
|
|
|
|
|
predict_x_start = False,
|
|
|
|
|
predict_v = False,
|
|
|
|
|
learned_variance = False,
|
|
|
|
|
clip_denoised = True,
|
|
|
|
|
lowres_cond_img = None,
|
|
|
|
|
@@ -2914,7 +2959,9 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
# predict x0
|
|
|
|
|
|
|
|
|
|
if predict_x_start:
|
|
|
|
|
if predict_v:
|
|
|
|
|
x_start = noise_scheduler.predict_start_from_v(img, t = time_cond, v = pred)
|
|
|
|
|
elif predict_x_start:
|
|
|
|
|
x_start = pred
|
|
|
|
|
else:
|
|
|
|
|
x_start = noise_scheduler.predict_start_from_noise(img, t = time_cond, noise = pred)
|
|
|
|
|
@@ -2926,8 +2973,8 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
# predict noise
|
|
|
|
|
|
|
|
|
|
if predict_x_start:
|
|
|
|
|
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = pred)
|
|
|
|
|
if predict_x_start or predict_v:
|
|
|
|
|
pred_noise = noise_scheduler.predict_noise_from_start(img, t = time_cond, x0 = x_start)
|
|
|
|
|
else:
|
|
|
|
|
pred_noise = pred
|
|
|
|
|
|
|
|
|
|
@@ -2963,7 +3010,7 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
return self.p_sample_loop_ddim(*args, noise_scheduler = noise_scheduler, timesteps = timesteps, **kwargs)
|
|
|
|
|
|
|
|
|
|
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
|
|
|
|
|
def p_losses(self, unet, x_start, times, *, image_embed, noise_scheduler, lowres_cond_img = None, text_encodings = None, predict_x_start = False, predict_v = False, noise = None, learned_variance = False, clip_denoised = False, is_latent_diffusion = False, lowres_noise_level = None):
|
|
|
|
|
noise = default(noise, lambda: torch.randn_like(x_start))
|
|
|
|
|
|
|
|
|
|
# normalize to [-1, 1]
|
|
|
|
|
@@ -3008,7 +3055,12 @@ class Decoder(nn.Module):
|
|
|
|
|
|
|
|
|
|
pred, _ = self.parse_unet_output(learned_variance, unet_output)
|
|
|
|
|
|
|
|
|
|
target = noise if not predict_x_start else x_start
|
|
|
|
|
if predict_v:
|
|
|
|
|
target = noise_scheduler.calculate_v(x_start, times, noise)
|
|
|
|
|
elif predict_x_start:
|
|
|
|
|
target = x_start
|
|
|
|
|
else:
|
|
|
|
|
target = noise
|
|
|
|
|
|
|
|
|
|
loss = noise_scheduler.loss_fn(pred, target, reduction = 'none')
|
|
|
|
|
loss = reduce(loss, 'b ... -> b (...)', 'mean')
|
|
|
|
|
@@ -3094,7 +3146,7 @@ class Decoder(nn.Module):
|
|
|
|
|
num_unets = self.num_unets
|
|
|
|
|
cond_scale = cast_tuple(cond_scale, num_unets)
|
|
|
|
|
|
|
|
|
|
for unet_number, unet, vae, channel, image_size, predict_x_start, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
|
|
|
|
|
for unet_number, unet, vae, channel, image_size, predict_x_start, predict_v, learned_variance, noise_scheduler, lowres_cond, sample_timesteps, unet_cond_scale in tqdm(zip(range(1, num_unets + 1), self.unets, self.vaes, self.sample_channels, self.image_sizes, self.predict_x_start, self.predict_v, self.learned_variance, self.noise_schedulers, self.lowres_conds, self.sample_timesteps, cond_scale)):
|
|
|
|
|
if unet_number < start_at_unet_number:
|
|
|
|
|
continue # It's the easiest way to do it
|
|
|
|
|
|
|
|
|
|
@@ -3130,6 +3182,7 @@ class Decoder(nn.Module):
|
|
|
|
|
text_encodings = text_encodings,
|
|
|
|
|
cond_scale = unet_cond_scale,
|
|
|
|
|
predict_x_start = predict_x_start,
|
|
|
|
|
predict_v = predict_v,
|
|
|
|
|
learned_variance = learned_variance,
|
|
|
|
|
clip_denoised = not is_latent_diffusion,
|
|
|
|
|
lowres_cond_img = lowres_cond_img,
|
|
|
|
|
@@ -3169,6 +3222,7 @@ class Decoder(nn.Module):
|
|
|
|
|
lowres_conditioner = self.lowres_conds[unet_index]
|
|
|
|
|
target_image_size = self.image_sizes[unet_index]
|
|
|
|
|
predict_x_start = self.predict_x_start[unet_index]
|
|
|
|
|
predict_v = self.predict_v[unet_index]
|
|
|
|
|
random_crop_size = self.random_crop_sizes[unet_index]
|
|
|
|
|
learned_variance = self.learned_variance[unet_index]
|
|
|
|
|
b, c, h, w, device, = *image.shape, image.device
|
|
|
|
|
@@ -3207,7 +3261,7 @@ class Decoder(nn.Module):
|
|
|
|
|
image = vae.encode(image)
|
|
|
|
|
lowres_cond_img = maybe(vae.encode)(lowres_cond_img)
|
|
|
|
|
|
|
|
|
|
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, predict_v = predict_v, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler, lowres_noise_level = lowres_noise_level)
|
|
|
|
|
|
|
|
|
|
if not return_lowres_cond_image:
|
|
|
|
|
return losses
|
|
|
|
|
|