make self conditioning technique work with diffusion prior

This commit is contained in:
Phil Wang
2022-08-12 12:20:51 -07:00
parent 981d407792
commit 9440411954
2 changed files with 42 additions and 10 deletions

View File

@@ -937,9 +937,12 @@ class DiffusionPriorNetwork(nn.Module):
num_image_embeds = 1,
num_text_embeds = 1,
max_text_len = 256,
self_cond = False,
**kwargs
):
super().__init__()
self.dim = dim
self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds
self.num_text_embeds = num_text_embeds
@@ -967,6 +970,10 @@ class DiffusionPriorNetwork(nn.Module):
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
# whether to use self conditioning, Hinton's group's new ddpm technique
self.self_cond = self_cond
def forward_with_cond_scale(
self,
*args,
@@ -988,12 +995,19 @@ class DiffusionPriorNetwork(nn.Module):
*,
text_embed,
text_encodings = None,
self_cond = None,
cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
# setup self conditioning
self_cond = None
if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -1043,13 +1057,16 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
if self.self_cond:
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
tokens = torch.cat((
text_encodings,
text_embed,
@@ -1151,10 +1168,10 @@ class DiffusionPrior(nn.Module):
def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
if self.predict_x_start:
x_start = pred
@@ -1168,28 +1185,33 @@ class DiffusionPrior(nn.Module):
x_start = l2norm(x_start) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad()
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
batch, device = shape[0], self.device
image_embed = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
self_cond = x_start if self.net.self_cond else None
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
@@ -1207,6 +1229,8 @@ class DiffusionPrior(nn.Module):
image_embed = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
@@ -1216,7 +1240,9 @@ class DiffusionPrior(nn.Module):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
self_cond = x_start if self.net.self_cond else None
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
x_start = pred
@@ -1260,9 +1286,15 @@ class DiffusionPrior(nn.Module):
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
self_cond = None
if self.net.self_cond and random.random() < 0.5:
with torch.no_grad():
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
pred = self.net(
image_embed_noisy,
times,
self_cond = self_cond,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)

View File

@@ -1 +1 @@
__version__ = '1.6.0'
__version__ = '1.6.1'