refactor so that the causal transformer in the diffusion prior network can be conditioned without text encodings (for LAION's parallel efforts, although it seems from the paper that they are needed)

Phil Wang
2022-04-26 09:00:11 -07:00
parent 4d25976f33
commit 76e063e8b7
2 changed files with 30 additions and 17 deletions
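
In effect, the prior network's causal transformer can now be run from the CLIP text embedding alone, with the per-token text encodings (and their mask) omitted. A minimal sketch of the new call, assuming the constructor arguments shown in the repository README and hypothetical toy tensors:

import torch
from dalle2_pytorch import DiffusionPriorNetwork

# toy configuration mirroring the README example
prior_network = DiffusionPriorNetwork(
    dim = 512,
    depth = 6,
    dim_head = 64,
    heads = 8
)

image_embed = torch.randn(4, 512)          # noised CLIP image embeddings
text_embed = torch.randn(4, 512)           # CLIP text embeddings
timesteps = torch.randint(0, 100, (4,))    # diffusion timesteps

# after this commit, text_encodings and mask are optional keyword arguments
pred_image_embed = prior_network(
    image_embed,
    timesteps,
    text_embed = text_embed
)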

diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py

@@ -421,25 +421,41 @@ class DiffusionPriorNetwork(nn.Module):
         image_embed,
         diffusion_timesteps,
         *,
-        text_encodings,
         text_embed,
+        text_encodings = None,
         mask = None,
         cond_drop_prob = 0.2
     ):
-        batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
+        batch, dim, device, dtype = *image_embed.shape, image_embed.device, image_embed.dtype
 
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
 
         text_embed, image_embed = rearrange_many((text_embed, image_embed), 'b d -> b 1 d')
 
+        # make text encodings optional
+        # although the paper seems to suggest it is present <--
+
+        if not exists(text_encodings):
+            text_encodings = torch.empty((batch, 0, dim), device = device, dtype = dtype)
+
+        if not exists(mask):
+            mask = torch.ones((batch, text_encodings.shape[-2]), device = device, dtype = torch.bool)
+
+        # classifier free guidance
+
+        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
+        cond_prob_mask = rearrange(cond_prob_mask, 'b -> b 1')
+
+        mask &= cond_prob_mask
+
+        # whether text embedding is masked or not depends on the classifier free guidance conditional masking
+
+        mask = torch.cat((mask, cond_prob_mask), dim = 1)
+
         # whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
         # but let's just do it right
 
-        if exists(mask):
-            not_all_masked_out = mask.any(dim = -1)
-            mask = torch.cat((mask, rearrange(not_all_masked_out, 'b -> b 1')), dim = 1)
-
         if exists(mask):
             mask = F.pad(mask, (0, 2), value = True) # extend mask for text embedding, noised image embedding, time step embedding, and learned query
@@ -455,16 +471,6 @@ class DiffusionPriorNetwork(nn.Module):
             learned_queries
         ), dim = -2)
 
-        # mask if it doesn't exist
-
-        if not exists(mask):
-            mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
-
-        # classifier free guidance
-
-        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
-        mask &= rearrange(cond_prob_mask, 'b -> b 1')
-
         # attend
 
         tokens = self.causal_transformer(tokens, mask = mask)
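
The refactor hinges on a zero-length sequence standing in for the absent text encodings: concatenation along the sequence dimension contributes no tokens, and the matching (batch, 0) mask extends the same way, so the transformer code path is unchanged. A self-contained sketch of the trick with toy shapes (independent of the repository):

import torch

batch, dim = 2, 8

text_encodings = torch.empty((batch, 0, dim))   # absent text encodings as a zero-length sequence
text_embed = torch.randn((batch, 1, dim))

tokens = torch.cat((text_encodings, text_embed), dim = -2)   # contributes zero tokens
assert tokens.shape == (batch, 1, dim)

mask = torch.ones((batch, text_encodings.shape[-2]), dtype = torch.bool)       # shape (batch, 0)
mask = torch.cat((mask, torch.ones((batch, 1), dtype = torch.bool)), dim = 1)  # extend for text embed
assert mask.shape == (batch, 1)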
@@ -486,6 +492,7 @@ class DiffusionPrior(nn.Module):
loss_type = "l1", loss_type = "l1",
predict_x_start = True, predict_x_start = True,
beta_schedule = "cosine", beta_schedule = "cosine",
condition_on_text_encodings = True, # the paper suggests this is needed, but you can turn it off for your CLIP preprocessed text embed -> image embed training
): ):
super().__init__() super().__init__()
assert isinstance(clip, CLIP) assert isinstance(clip, CLIP)
@@ -496,7 +503,9 @@ class DiffusionPrior(nn.Module):
         self.image_embed_dim = clip.dim_latent
         self.channels = clip.image_channels
         self.image_size = clip.image_size
 
         self.cond_drop_prob = cond_drop_prob
+        self.condition_on_text_encodings = condition_on_text_encodings
+
         self.predict_x_start = predict_x_start
         # in paper, they do not predict the noise, but predict x0 directly for image embedding, claiming empirically better results. I'll just offer both.
@@ -561,6 +570,10 @@ class DiffusionPrior(nn.Module):
         text_cls, text_encodings = text_encodings[:, 0], text_encodings[:, 1:]
         text_embed = self.clip.to_text_latent(text_cls)
         text_embed = l2norm(text_embed)
 
+        if not self.condition_on_text_encodings:
+            return dict(text_embed = text_embed)
+
         return dict(text_encodings = text_encodings, text_embed = text_embed, mask = text != 0)
 
     def q_mean_variance(self, x_start, t):
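
At the DiffusionPrior level, the new flag lets the prior train on CLIP text embedding -> image embedding pairs alone. A hedged sketch following the README's training snippet, where clip, text, and images are assumed to be set up as in the README:

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2,
    condition_on_text_encodings = False   # new flag: skip per-token text encodings during conditioning
)

loss = diffusion_prior(text, images)
loss.backward()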

diff --git a/setup.py b/setup.py

@@ -10,7 +10,7 @@ setup(
             'dream = dalle2_pytorch.cli:dream'
         ],
     },
-    version = '0.0.46',
+    version = '0.0.47',
     license='MIT',
     description = 'DALL-E 2',
     author = 'Phil Wang',