Compare commits


5 Commits
1.6.0 ... 1.6.5

Author     SHA1        Message                                                                          Date
Phil Wang  34806663e3  make it so diffusion prior p_sample_loop returns unnormalized image embeddings  2022-08-13 10:03:40 -07:00
Phil Wang  dc816b1b6e  dry up some code around handling unet outputs with learned variance             2022-08-12 15:25:03 -07:00
Phil Wang  05192ffac4  fix self conditioning shape in diffusion prior                                   2022-08-12 12:30:03 -07:00
Phil Wang  9440411954  make self conditioning technique work with diffusion prior                      2022-08-12 12:20:51 -07:00
Phil Wang  981d407792  comment                                                                          2022-08-12 11:41:23 -07:00
2 changed files with 67 additions and 32 deletions

dalle2_pytorch/dalle2_pytorch.py

@@ -38,6 +38,8 @@ from coca_pytorch import CoCa
 NAT = 1. / math.log(2.)
+UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])
 # helper functions
 def exists(val):
@@ -937,9 +939,12 @@ class DiffusionPriorNetwork(nn.Module):
         num_image_embeds = 1,
         num_text_embeds = 1,
         max_text_len = 256,
+        self_cond = False,
         **kwargs
     ):
         super().__init__()
+        self.dim = dim
         self.num_time_embeds = num_time_embeds
         self.num_image_embeds = num_image_embeds
         self.num_text_embeds = num_text_embeds
@@ -967,6 +972,10 @@ class DiffusionPriorNetwork(nn.Module):
         self.max_text_len = max_text_len
         self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
+        # whether to use self conditioning, Hinton's group's new ddpm technique
+        self.self_cond = self_cond
     def forward_with_cond_scale(
         self,
         *args,
@@ -988,12 +997,19 @@ class DiffusionPriorNetwork(nn.Module):
         *,
         text_embed,
         text_encodings = None,
+        self_cond = None,
         cond_drop_prob = 0.
     ):
         batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
         num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
+        # setup self conditioning
+        if self.self_cond:
+            self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
+            self_cond = rearrange(self_cond, 'b d -> b 1 d')
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -1043,13 +1059,16 @@ class DiffusionPriorNetwork(nn.Module):
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
-        attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
+        attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
         mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
         time_embed = self.to_time_embeds(diffusion_timesteps)
         learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
+        if self.self_cond:
+            learned_queries = torch.cat((image_embed, self_cond), dim = -2)
         tokens = torch.cat((
             text_encodings,
             text_embed,
@@ -1151,10 +1170,10 @@ class DiffusionPrior(nn.Module):
     def l2norm_clamp_embed(self, image_embed):
         return l2norm(image_embed) * self.image_embed_scale
-    def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
+    def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
-        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
+        pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
         if self.predict_x_start:
             x_start = pred
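The body of forward_with_cond_scale is not shown in this diff; the assert above guards its cond_scale argument. A hedged sketch of the usual classifier-free guidance combine that such a wrapper performs, under the assumption that cond_drop_prob = 1. yields the fully unconditioned forward pass:

# sketch of classifier-free guidance, not the verbatim repo code
def forward_with_cond_scale_sketch(net, x, t, cond_scale = 1., **cond):
    logits = net(x, t, cond_drop_prob = 0., **cond)          # conditioned prediction
    if cond_scale == 1.:
        return logits
    null_logits = net(x, t, cond_drop_prob = 1., **cond)     # unconditioned prediction
    return null_logits + (logits - null_logits) * cond_scale # extrapolate past the conditioned output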
@@ -1168,28 +1187,33 @@ class DiffusionPrior(nn.Module):
             x_start = l2norm(x_start) * self.image_embed_scale
         model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
-        return model_mean, posterior_variance, posterior_log_variance
+        return model_mean, posterior_variance, posterior_log_variance, x_start
     @torch.no_grad()
-    def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
+    def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
         b, *_, device = *x.shape, x.device
-        model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
+        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
         noise = torch.randn_like(x)
         # no noise when t == 0
         nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
-        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+        return pred, x_start
     @torch.no_grad()
     def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
         batch, device = shape[0], self.device
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
         for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
             times = torch.full((batch,), i, device = device, dtype = torch.long)
+            self_cond = x_start if self.net.self_cond else None
+            image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
         if self.sampling_final_clamp_l2norm and self.predict_x_start:
             image_embed = self.l2norm_clamp_embed(image_embed)
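The loop change above is the crux of self-conditioning at inference time: p_sample now also returns the model's x_start estimate, and the next step receives it back as self_cond. A runnable toy of that feedback structure (the stand-in p_sample is illustrative only):

import torch

num_timesteps, shape, use_self_cond = 10, (4, 512), True

def p_sample(x, t, self_cond = None):
    # stand-in for the real sampler: returns (next x, predicted x_start)
    x_start_est = x * 0.9  # dummy x0 estimate, for illustration
    return x - 0.1 * x, x_start_est

x = torch.randn(shape)  # start from pure noise
x_start = None          # no estimate available at the first (most noised) step

for t in reversed(range(num_timesteps)):
    self_cond = x_start if use_self_cond else None      # feed back last step's x0 estimate
    x, x_start = p_sample(x, t, self_cond = self_cond)  # each step also yields a fresh estimate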
@@ -1207,6 +1231,8 @@ class DiffusionPrior(nn.Module):
         image_embed = torch.randn(shape, device = device)
+        x_start = None # for self-conditioning
         if self.init_image_embed_l2norm:
             image_embed = l2norm(image_embed) * self.image_embed_scale
@@ -1216,7 +1242,9 @@ class DiffusionPrior(nn.Module):
             time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
-            pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
+            self_cond = x_start if self.net.self_cond else None
+            pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
             if self.predict_x_start:
                 x_start = pred
@@ -1251,18 +1279,27 @@ class DiffusionPrior(nn.Module):
         is_ddim = timesteps < self.noise_scheduler.num_timesteps
         if not is_ddim:
-            return self.p_sample_loop_ddpm(*args, **kwargs)
+            normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
+        else:
+            normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
-        return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
+        image_embed = normalized_image_embed / self.image_embed_scale
+        return image_embed
     def p_losses(self, image_embed, times, text_cond, noise = None):
         noise = default(noise, lambda: torch.randn_like(image_embed))
         image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
+        self_cond = None
+        if self.net.self_cond and random.random() < 0.5:
+            with torch.no_grad():
+                self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
         pred = self.net(
             image_embed_noisy,
             times,
+            self_cond = self_cond,
             cond_drop_prob = self.cond_drop_prob,
             **text_cond
         )
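The p_losses addition above is the training half of the self-conditioning recipe from https://arxiv.org/abs/2208.04202: half the time, run the network once without gradients to get a prediction, then condition the real forward pass on it. A toy of the two-pass structure (nn.Linear stands in for the prior network; the real network takes self_cond as an explicit argument rather than added to the input):

import random
import torch
from torch import nn

net = nn.Linear(512, 512)      # stand-in denoiser, for illustration only
x_noisy = torch.randn(4, 512)

self_cond = None
if random.random() < 0.5:      # only half of the training steps pay for the extra pass
    with torch.no_grad():      # no autograd graph is built for the first pass...
        self_cond = net(x_noisy).detach()  # ...and detach makes the feedback a pure constant

# second, gradient-carrying pass; conditioning is simplified to an additive input here
pred = net(x_noisy if self_cond is None else x_noisy + self_cond)

Because the first pass runs under no_grad, memory stays flat and the added compute averages well under one extra forward per step.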
@@ -1316,8 +1353,6 @@ class DiffusionPrior(nn.Module):
-        # retrieve original unscaled image embed
-        image_embeds /= self.image_embed_scale
         text_embeds = text_cond['text_embed']
         text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
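This deletion pairs with the p_sample_loop change above: the division by image_embed_scale now happens once inside p_sample_loop, so sample no longer rescales. A quick illustration of the round-trip invariant (the sqrt-of-dim scale is the typical default in this repo, used here as an assumed value):

import torch

image_embed_scale = 512 ** 0.5                       # assumed default: sqrt(embedding dim)
internal = torch.randn(4, 512) * image_embed_scale   # the scale the sampler works at internally
returned = internal / image_embed_scale              # what p_sample_loop now returns
assert torch.allclose(returned * image_embed_scale, internal)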
@@ -1700,7 +1735,7 @@ class Unet(nn.Module):
         attn_heads = 16,
         lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
         lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
-        self_cond = False,
+        self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
         sparse_attn = False,
         cosine_sim_cross_attn = False,
         cosine_sim_self_attn = False,
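With the new comment, the flag is discoverable at the constructor. Enabling it might look like the following; the dim and image_embed_dim values are placeholders, and all other arguments are left at their defaults:

from dalle2_pytorch import Unet

unet = Unet(
    dim = 128,             # placeholder base channel width
    image_embed_dim = 512, # placeholder CLIP image embedding dim
    self_cond = True       # opt in to self-conditioning per https://arxiv.org/abs/2208.04202
)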
@@ -2552,6 +2587,14 @@ class Decoder(nn.Module):
         index = unet_number - 1
         return self.unets[index]
+    def parse_unet_output(self, learned_variance, output):
+        var_interp_frac_unnormalized = None
+        if learned_variance:
+            output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
+        return UnetOutput(output, var_interp_frac_unnormalized)
     @contextmanager
     def one_unet_in_gpu(self, unet_number = None, unet = None):
         assert exists(unet_number) ^ exists(unet)
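parse_unet_output above is what "dries up" the repeated chunk-and-branch logic at the call sites below: paired with the UnetOutput namedtuple from the top of the file, every caller unpacks the same two fields whether or not the unet predicts a variance interpolation fraction. A standalone mirror of it, runnable as-is:

import torch
from collections import namedtuple

UnetOutput = namedtuple('UnetOutput', ['pred', 'var_interp_frac_unnormalized'])

def parse_unet_output(learned_variance, output):
    var_interp_frac_unnormalized = None
    if learned_variance:
        output, var_interp_frac_unnormalized = output.chunk(2, dim = 1)
    return UnetOutput(output, var_interp_frac_unnormalized)

out = torch.randn(2, 6, 64, 64)                   # learned variance doubles the channel dim
pred, var_interp = parse_unet_output(True, out)   # both halves come back, each (2, 3, 64, 64)
pred, var_interp = parse_unet_output(False, pred) # var_interp is simply None when not learned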
@@ -2593,10 +2636,9 @@ class Decoder(nn.Module):
     def p_mean_variance(self, unet, x, t, image_embed, noise_scheduler, text_encodings = None, lowres_cond_img = None, self_cond = None, clip_denoised = True, predict_x_start = False, learned_variance = False, cond_scale = 1., model_output = None, lowres_noise_level = None):
         assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the decoder was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
-        pred = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
+        model_output = default(model_output, lambda: unet.forward_with_cond_scale(x, t, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, lowres_cond_img = lowres_cond_img, self_cond = self_cond, lowres_noise_level = lowres_noise_level))
-        if learned_variance:
-            pred, var_interp_frac_unnormalized = pred.chunk(2, dim = 1)
+        pred, var_interp_frac_unnormalized = self.parse_unet_output(learned_variance, model_output)
         if predict_x_start:
             x_start = pred
@@ -2779,10 +2821,9 @@ class Decoder(nn.Module):
             self_cond = x_start if unet.self_cond else None
-            pred = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
+            unet_output = unet.forward_with_cond_scale(img, time_cond, image_embed = image_embed, text_encodings = text_encodings, cond_scale = cond_scale, self_cond = self_cond, lowres_cond_img = lowres_cond_img, lowres_noise_level = lowres_noise_level)
-            if learned_variance:
-                pred, _ = pred.chunk(2, dim = 1)
+            pred, _ = self.parse_unet_output(learned_variance, unet_output)
             if predict_x_start:
                 x_start = pred
@@ -2854,16 +2895,13 @@ class Decoder(nn.Module):
         if unet.self_cond and random.random() < 0.5:
             with torch.no_grad():
-                self_cond = unet(x_noisy, times, **unet_kwargs)
-                if learned_variance:
-                    self_cond, _ = self_cond.chunk(2, dim = 1)
+                unet_output = unet(x_noisy, times, **unet_kwargs)
+                self_cond, _ = self.parse_unet_output(learned_variance, unet_output)
                 self_cond = self_cond.detach()
         # forward to get model prediction
-        model_output = unet(
+        unet_output = unet(
             x_noisy,
             times,
             **unet_kwargs,
@@ -2872,10 +2910,7 @@ class Decoder(nn.Module):
             text_cond_drop_prob = self.text_cond_drop_prob,
         )
-        if learned_variance:
-            pred, _ = model_output.chunk(2, dim = 1)
-        else:
-            pred = model_output
+        pred, _ = self.parse_unet_output(learned_variance, unet_output)
         target = noise if not predict_x_start else x_start
@@ -2898,7 +2933,7 @@
             # if learning the variance, also include the extra weight kl loss
             true_mean, _, true_log_variance_clipped = noise_scheduler.q_posterior(x_start = x_start, x_t = x_noisy, t = times)
-            model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = model_output)
+            model_mean, _, model_log_variance, _ = self.p_mean_variance(unet, x = x_noisy, t = times, image_embed = image_embed, noise_scheduler = noise_scheduler, clip_denoised = clip_denoised, learned_variance = True, model_output = unet_output)
             # kl loss with detached model predicted mean, for stability reasons as in paper
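For reference, the KL term computed just after this point compares two diagonal Gaussians, and with the model mean detached only the learned variance receives gradient from it (the stabilization used in Improved DDPM). A standalone version of that closed form, with the repo's variable names shown in the commented call:

import torch

def normal_kl(mean1, logvar1, mean2, logvar2):
    # KL( N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)) ), elementwise
    return 0.5 * (-1.0 + logvar2 - logvar1
                  + torch.exp(logvar1 - logvar2)
                  + ((mean1 - mean2) ** 2) * torch.exp(-logvar2))

# kl = normal_kl(true_mean, true_log_variance_clipped, model_mean.detach(), model_log_variance)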

dalle2_pytorch/version.py

@@ -1 +1 @@
-__version__ = '1.6.0'
+__version__ = '1.6.5'