mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
make self conditioning technique work with diffusion prior
This commit is contained in:
@@ -937,9 +937,12 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
num_image_embeds = 1,
|
||||
num_text_embeds = 1,
|
||||
max_text_len = 256,
|
||||
self_cond = False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
self.num_time_embeds = num_time_embeds
|
||||
self.num_image_embeds = num_image_embeds
|
||||
self.num_text_embeds = num_text_embeds
|
||||
@@ -967,6 +970,10 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
self.max_text_len = max_text_len
|
||||
self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, dim))
|
||||
|
||||
# whether to use self conditioning, Hinton's group's new ddpm technique
|
||||
|
||||
self.self_cond = self_cond
|
||||
|
||||
def forward_with_cond_scale(
|
||||
self,
|
||||
*args,
|
||||
@@ -988,12 +995,19 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
*,
|
||||
text_embed,
|
||||
text_encodings = None,
|
||||
self_cond = None,
|
||||
cond_drop_prob = 0.
|
||||
):
|
||||
batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
|
||||
|
||||
num_time_embeds, num_image_embeds, num_text_embeds = self.num_time_embeds, self.num_image_embeds, self.num_text_embeds
|
||||
|
||||
# setup self conditioning
|
||||
|
||||
self_cond = None
|
||||
if self.self_cond:
|
||||
self_cond = default(self_cond, lambda: torch.zeros(batch, 1, self.dim, device = device, dtype = dtype))
|
||||
|
||||
# in section 2.2, last paragraph
|
||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||
|
||||
@@ -1043,13 +1057,16 @@ class DiffusionPriorNetwork(nn.Module):
|
||||
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
|
||||
# but let's just do it right
|
||||
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds # 1 for learned queries + number of image embeds + time embeds
|
||||
attend_padding = 1 + num_time_embeds + num_image_embeds + int(self.self_cond) # 1 for learned queries + number of image embeds + time embeds
|
||||
mask = F.pad(mask, (0, attend_padding), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
|
||||
|
||||
time_embed = self.to_time_embeds(diffusion_timesteps)
|
||||
|
||||
learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
|
||||
|
||||
if self.self_cond:
|
||||
learned_queries = torch.cat((image_embed, self_cond), dim = -2)
|
||||
|
||||
tokens = torch.cat((
|
||||
text_encodings,
|
||||
text_embed,
|
||||
@@ -1151,10 +1168,10 @@ class DiffusionPrior(nn.Module):
|
||||
def l2norm_clamp_embed(self, image_embed):
|
||||
return l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
def p_mean_variance(self, x, t, text_cond, clip_denoised = False, cond_scale = 1.):
|
||||
def p_mean_variance(self, x, t, text_cond, self_cond = None, clip_denoised = False, cond_scale = 1.):
|
||||
assert not (cond_scale != 1. and not self.can_classifier_guidance), 'the model was not trained with conditional dropout, and thus one cannot use classifier free guidance (cond_scale anything other than 1)'
|
||||
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, **text_cond)
|
||||
pred = self.net.forward_with_cond_scale(x, t, cond_scale = cond_scale, self_cond = self_cond, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_start = pred
|
||||
@@ -1168,28 +1185,33 @@ class DiffusionPrior(nn.Module):
|
||||
x_start = l2norm(x_start) * self.image_embed_scale
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
return model_mean, posterior_variance, posterior_log_variance, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, text_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||
def p_sample(self, x, t, text_cond = None, self_cond = None, clip_denoised = True, cond_scale = 1.):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, text_cond = text_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = t, text_cond = text_cond, self_cond = self_cond, clip_denoised = clip_denoised, cond_scale = cond_scale)
|
||||
noise = torch.randn_like(x)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
pred = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
return pred, x_start
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop_ddpm(self, shape, text_cond, cond_scale = 1.):
|
||||
batch, device = shape[0], self.device
|
||||
|
||||
image_embed = torch.randn(shape, device = device)
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
for i in tqdm(reversed(range(0, self.noise_scheduler.num_timesteps)), desc='sampling loop time step', total=self.noise_scheduler.num_timesteps):
|
||||
times = torch.full((batch,), i, device = device, dtype = torch.long)
|
||||
image_embed = self.p_sample(image_embed, times, text_cond = text_cond, cond_scale = cond_scale)
|
||||
|
||||
self_cond = x_start if self.net.self_cond else None
|
||||
image_embed, x_start = self.p_sample(image_embed, times, text_cond = text_cond, self_cond = self_cond, cond_scale = cond_scale)
|
||||
|
||||
if self.sampling_final_clamp_l2norm and self.predict_x_start:
|
||||
image_embed = self.l2norm_clamp_embed(image_embed)
|
||||
@@ -1207,6 +1229,8 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
image_embed = torch.randn(shape, device = device)
|
||||
|
||||
x_start = None # for self-conditioning
|
||||
|
||||
if self.init_image_embed_l2norm:
|
||||
image_embed = l2norm(image_embed) * self.image_embed_scale
|
||||
|
||||
@@ -1216,7 +1240,9 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
|
||||
|
||||
pred = self.net.forward_with_cond_scale(image_embed, time_cond, cond_scale = cond_scale, **text_cond)
|
||||
self_cond = x_start if self.net.self_cond else None
|
||||
|
||||
pred = self.net.forward_with_cond_scale(image_embed, time_cond, self_cond = self_cond, cond_scale = cond_scale, **text_cond)
|
||||
|
||||
if self.predict_x_start:
|
||||
x_start = pred
|
||||
@@ -1260,9 +1286,15 @@ class DiffusionPrior(nn.Module):
|
||||
|
||||
image_embed_noisy = self.noise_scheduler.q_sample(x_start = image_embed, t = times, noise = noise)
|
||||
|
||||
self_cond = None
|
||||
if self.net.self_cond and random.random() < 0.5:
|
||||
with torch.no_grad():
|
||||
self_cond = self.net(image_embed_noisy, times, **text_cond).detach()
|
||||
|
||||
pred = self.net(
|
||||
image_embed_noisy,
|
||||
times,
|
||||
self_cond = self_cond,
|
||||
cond_drop_prob = self.cond_drop_prob,
|
||||
**text_cond
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user