complete conditional dropout mask creation for both the prior network and the image decoder unet, for classifier free guidance

This commit is contained in:
Phil Wang
2022-04-12 14:04:02 -07:00
parent c814b2b278
commit d4c8373635


@@ -42,6 +42,16 @@ def freeze_model_and_make_eval_(model):
    model.eval()
    freeze_all_layers_(model)
# classifier free guidance functions

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
# diffusion prior

class RMSNorm(nn.Module):
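For reference, the general branch of prob_mask_like is just a uniform draw thresholded at prob, so each element of the returned boolean tensor is independently True with probability prob. A minimal standalone check (plain torch, not part of the diff):

import torch

device = torch.device('cpu')

# equivalent to the general branch above: threshold uniform noise,
# giving True with probability `prob` per element
mask = torch.zeros((100000,), device = device).float().uniform_(0, 1) < 0.75

print(mask.dtype)           # torch.bool
print(mask.float().mean())  # ≈ 0.75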
@@ -179,7 +189,7 @@ class DiffusionPriorNetwork(nn.Module):
        mask = None,
        cond_drop_prob = 0.2
    ):
        batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device

        # in section 2.2, last paragraph
        # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -200,6 +210,18 @@ class DiffusionPriorNetwork(nn.Module):
            learned_queries
        ), dim = -2)
        # mask if it doesn't exist

        if not exists(mask):
            mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)

        # classifier free guidance - keep conditioning with probability (1 - cond_drop_prob)

        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
        mask &= rearrange(keep_mask, 'b -> b 1')
        # attend

        tokens = self.causal_transformer(tokens, mask = mask)

        # get learned query, which should predict the image embedding (per DDPM timestep)
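A standalone sketch of what the hunk above does, assuming only torch and einops and hypothetical sizes: sample one keep/drop decision per batch element, then broadcast it across the text sequence so that dropped rows attend to no text tokens. Note the keep probability is 1 - cond_drop_prob, i.e. conditioning is dropped with probability cond_drop_prob:

import torch
from einops import rearrange

batch, text_enc_len = 4, 77   # hypothetical sizes
device = torch.device('cpu')
cond_drop_prob = 0.2

# default text mask: attend to every text token
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)

# one keep/drop decision per batch element, True with probability (1 - cond_drop_prob)
keep_mask = torch.zeros((batch,), device = device).float().uniform_(0, 1) < (1 - cond_drop_prob)

# broadcast to (batch, text_enc_len): dropped rows are masked out entirely
mask = mask & rearrange(keep_mask, 'b -> b 1')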
@@ -358,8 +380,11 @@ class Unet(nn.Module):
        text_encodings = None,
        cond_drop_prob = 0.2
    ):
        batch_size, device = image_embed.shape[0], image_embed.device

        t = self.time_mlp(time) if exists(self.time_mlp) else None
        # classifier free guidance - keep conditioning with probability (1 - cond_drop_prob)

        keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
        hiddens = []

        for convnext, convnext2, downsample in self.downs:
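These dropout masks are the training half of classifier free guidance; at sampling time the model is run twice and the two predictions are blended. A hedged sketch of that blend (the formula is the standard classifier free guidance combination, not code from this commit; guided_pred and cond_scale are hypothetical names, while the cond_drop_prob keyword matches the forward signatures in the diff):

import torch

def guided_pred(model, *args, cond_scale = 3., **kwargs):
    # conditional pass: never drop conditioning
    cond = model(*args, cond_drop_prob = 0., **kwargs)
    # unconditional pass: always drop conditioning
    uncond = model(*args, cond_drop_prob = 1., **kwargs)
    # standard classifier free guidance: push the prediction
    # away from unconditional and toward conditional
    return uncond + (cond - uncond) * cond_scale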