diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py
index 681df5d..136c4e7 100644
--- a/dalle2_pytorch/dalle2_pytorch.py
+++ b/dalle2_pytorch/dalle2_pytorch.py
@@ -42,6 +42,16 @@ def freeze_model_and_make_eval_(model):
     model.eval()
     freeze_all_layers_(model)
 
+# classifier free guidance functions
+
+def prob_mask_like(shape, prob, device):
+    if prob == 1:
+        return torch.ones(shape, device = device, dtype = torch.bool)
+    elif prob == 0:
+        return torch.zeros(shape, device = device, dtype = torch.bool)
+    else:
+        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
+
 # diffusion prior
 
 class RMSNorm(nn.Module):
@@ -179,7 +189,7 @@ class DiffusionPriorNetwork(nn.Module):
         mask = None,
         cond_drop_prob = 0.2
     ):
-        batch = image_embed.shape[0]
+        batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
 
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -200,6 +210,18 @@ class DiffusionPriorNetwork(nn.Module):
             learned_queries
         ), dim = -2)
 
+        # create mask if it doesn't exist
+
+        if not exists(mask):
+            mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
+
+        # classifier free guidance - drop text conditioning for (cond_drop_prob * 100)% of the batch
+
+        keep_mask = prob_mask_like((batch,), 1 - cond_drop_prob, device = device)
+        mask &= rearrange(keep_mask, 'b -> b 1')
+
+        # attend
+
         tokens = self.causal_transformer(tokens, mask = mask)
 
         # get learned query, which should predict the image embedding (per DDPM timestep)
@@ -358,8 +380,11 @@ class Unet(nn.Module):
         text_encodings = None,
         cond_prob_drop = 0.2
     ):
+        batch_size, device = image_embed.shape[0], image_embed.device
         t = self.time_mlp(time) if exists(self.time_mlp) else None
 
+        keep_mask = prob_mask_like((batch_size,), 1 - cond_prob_drop, device = device)
+
         hiddens = []
 
         for convnext, convnext2, downsample in self.downs:
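
Usage note: the `cond_drop_prob` hook added above is the training half of classifier free guidance; at sampling time the network would be queried twice, once with conditioning kept and once with it fully dropped, and the two predictions blended. Below is a minimal sketch of that blending. `guided_predict`, `guidance_scale`, and the exact keyword layout of the prior network's forward are illustrative assumptions, not part of this diff; only `cond_drop_prob` itself comes from the patch.

    import torch

    @torch.no_grad()
    def guided_predict(prior_network, image_embed, times, *, text_embed, text_encodings, mask = None, guidance_scale = 2.0):
        # conditional branch: cond_drop_prob = 0 makes prob_mask_like return an all-True
        # keep mask, so every sample attends to its text tokens
        cond = prior_network(image_embed, times, text_embed = text_embed, text_encodings = text_encodings, mask = mask, cond_drop_prob = 0.)
        # unconditional branch: cond_drop_prob = 1 yields an all-False keep mask,
        # so the `mask &= ...` line above blocks attention to every text token
        uncond = prior_network(image_embed, times, text_embed = text_embed, text_encodings = text_encodings, mask = mask, cond_drop_prob = 1.)
        # guidance_scale > 1 extrapolates away from the unconditional prediction
        return uncond + guidance_scale * (cond - uncond)

Note that the keep mask is drawn with probability `1 - cond_drop_prob`: a drop probability of 0.2 should zero out text conditioning for roughly 20% of the batch, so the mask that is ANDed into the attention mask must be True for the other 80%.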