Compare commits

..

3 Commits
1.6.0 ... 1.6.2

Author SHA1 Message Date
Phil Wang
301a97197f fix self conditioning shape in diffusion prior 2022-08-12 12:29:25 -07:00
Phil Wang
9440411954 make self conditioning technique work with diffusion prior 2022-08-12 12:20:51 -07:00
Phil Wang
981d407792 comment 2022-08-12 11:41:23 -07:00
2 changed files with 43 additions and 11 deletions

View File

@@ -937,9 +937,12 @@ class DiffusionPriorNetwork(nn.Module):
num_image_embeds = 1,
num_text_embeds = 1,
max_text_len = 256,
self_cond = False,
**kwargs
):
super().__init__()
self.dim = dim
self.num_time_embeds = num_time_embeds
self.num_image_embeds = num_image_embeds
self.num_text_embeds = num_text_embeds
@@ -967,6 +970,10 @@ class DiffusionPriorNetwork(nn.Module):
self.max_text_len = max_text_len
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
# whether to use self conditioning, Hinton's group's new ddpm technique
self.self_cond = self_cond
def forward_with_cond_scale(
self,
*args,
@@ -988,12 +995,19 @@ class DiffusionPriorNetwork(nn.Module):
*,
text_embed,
text_encodings = None,
self_cond = None,
cond_drop_prob = 0.
):
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
# setup self conditioning
if self.self_cond:
self_cond = default(self_cond, lambda: torch.zeros(batch, self.dim, device = device, dtype = dtype))
self_cond = rearrange(self_cond, 'b d -> b 1 d')
# in section 2.2, last paragraph
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -1043,13 +1057,16 @@ class DiffusionPriorNetwork(nn.Module):
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
time_embed = self.to_time_embeds(diffusion_timesteps)
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
if self.self_cond:
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
tokens = torch.cat((
text_encodings,
text_embed,
@@ -1151,10 +1168,10 @@ class DiffusionPrior(nn.Module):
def l2norm_clamp_embed(self, image_embed):
return l2norm(image_embed) * self.image_embed_scale
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
if self.predict_x_start:
x_start = pred
@@ -1168,28 +1185,33 @@ class DiffusionPrior(nn.Module):
x_start = l2norm(x_start) * self.image_embed_scale
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
return model_mean, posterior_variance, posterior_log_variance, x_start
@torch.no_grad()
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return pred, x_start
@torch.no_grad()
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
batch, device = shape[0], self.device
image_embed = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
times = torch.full((batch,), i, device = device, dtype = torch.long)
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
self_cond = x_start if self.net.self_cond else None
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
if self.sampling_final_clamp_l2norm and self.predict_x_start:
image_embed = self.l2norm_clamp_embed(image_embed)
@@ -1207,6 +1229,8 @@ class DiffusionPrior(nn.Module):
image_embed = torch.randn(shape, device = device)
x_start = None # for self-conditioning
if self.init_image_embed_l2norm:
image_embed = l2norm(image_embed) * self.image_embed_scale
@@ -1216,7 +1240,9 @@ class DiffusionPrior(nn.Module):
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
self_cond = x_start if self.net.self_cond else None
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
if self.predict_x_start:
x_start = pred
@@ -1260,9 +1286,15 @@ class DiffusionPrior(nn.Module):
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
self_cond = None
if self.net.self_cond and random.random() < 1.5:
with torch.no_grad():
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
pred = self.net(
image_embed_noisy,
times,
self_cond = self_cond,
cond_drop_prob = self.cond_drop_prob,
**text_cond
)
@@ -1700,7 +1732,7 @@ class Unet(nn.Module):
attn_heads = 16,
lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
lowres_noise_cond = False, # for conditioning on low resolution noising, based on Imagen
self_cond = False,
self_cond = False, # set this to True to use the self-conditioning technique from - https://arxiv.org/abs/2208.04202
sparse_attn = False,
cosine_sim_cross_attn = False,
cosine_sim_self_attn = False,

View File

@@ -1 +1 @@
__version__ = '1.6.0'
__version__ = '1.6.2'