mirror of https://github.com/lucidrains/DALLE2-pytorch.git
complete conditional dropout mask creation for both prior network as well as image decoder unet for classifier free guidance
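The conditional dropout added here is the training-time half of classifier free guidance: by randomly blanking the conditioning signal, a single network learns both conditional and unconditional predictions. At sampling time the two are blended. A minimal sketch of that sampling step (not part of this commit; `model`, `cond`, and `guidance_scale` are hypothetical names):

    import torch

    @torch.no_grad()
    def guided_prediction(model, x, t, cond, guidance_scale = 3.):
        # run the model twice: once with conditioning kept, once with it dropped
        cond_pred   = model(x, t, cond = cond)
        uncond_pred = model(x, t, cond = None)
        # Ho & Salimans (2021): extrapolate the conditional prediction away
        # from the unconditional one by the guidance scale
        return uncond_pred + guidance_scale * (cond_pred - uncond_pred)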
@@ -42,6 +42,16 @@ def freeze_model_and_make_eval_(model):
     model.eval()
     freeze_all_layers_(model)
 
+# classifier free guidance functions
+
+def prob_mask_like(shape, prob, device):
+    if prob == 1:
+        return torch.ones(shape, device = device, dtype = torch.bool)
+    elif prob == 0:
+        return torch.zeros(shape, device = device, dtype = torch.bool)
+    else:
+        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
+
 # diffusion prior
 
 class RMSNorm(nn.Module):
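`prob_mask_like` draws an independent Bernoulli mask whose entries are True with probability `prob`. Worth noting: downstream the result is ANDed into an attention mask, so True means "keep the conditioning"; to drop conditioning 20% of the time, a caller would pass `1 - cond_drop_prob` rather than the drop probability itself. A small standalone check (the function is reproduced so the snippet runs on its own):

    import torch

    def prob_mask_like(shape, prob, device):
        if prob == 1:
            return torch.ones(shape, device = device, dtype = torch.bool)
        elif prob == 0:
            return torch.zeros(shape, device = device, dtype = torch.bool)
        else:
            return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

    # entries come up True with probability `prob`
    mask = prob_mask_like((100_000,), 0.8, device = torch.device('cpu'))
    print(mask.float().mean())  # prints roughly 0.8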
@@ -179,7 +189,7 @@ class DiffusionPriorNetwork(nn.Module):
         mask = None,
         cond_drop_prob = 0.2
     ):
-        batch = image_embed.shape[0]
+        batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
 
         # in section 2.2, last paragraph
         # "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
@@ -200,6 +210,18 @@ class DiffusionPriorNetwork(nn.Module):
             learned_queries
         ), dim = -2)
 
+        # create mask if it doesn't exist
+
+        if not exists(mask):
+            mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
+
+        # classifier free guidance
+
+        cond_prob_mask = prob_mask_like((batch,), cond_drop_prob, device = device)
+        mask &= rearrange(cond_prob_mask, 'b -> b 1')
+
+        # attend
+
         tokens = self.causal_transformer(tokens, mask = mask)
 
         # get learned query, which should predict the image embedding (per DDPM timestep)
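The two lines doing the conditioning dropout above are compact; a minimal demo of the broadcast (hypothetical shapes, not from the repository) shows how a single False per sample blanks that sample's entire row of text tokens:

    import torch
    from einops import rearrange

    batch, text_enc_len = 4, 6
    mask = torch.ones((batch, text_enc_len), dtype = torch.bool)  # every token visible
    cond_prob_mask = torch.tensor([True, False, True, True])      # sample 1 dropped

    # (b,) -> (b, 1), then broadcast against (b, n) during the in-place AND
    mask &= rearrange(cond_prob_mask, 'b -> b 1')
    print(mask)
    # row 1 is now all False, so the causal transformer cannot attend to any
    # of that sample's text encodings -- its prediction becomes unconditional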
@@ -358,8 +380,11 @@ class Unet(nn.Module):
         text_encodings = None,
+        cond_drop_prob = 0.2
     ):
         batch_size, device = image_embed.shape[0], image_embed.device
         t = self.time_mlp(time) if exists(self.time_mlp) else None
 
+        cond_prob_mask = prob_mask_like((batch_size,), cond_drop_prob, device = device)
+
         hiddens = []
 
         for convnext, convnext2, downsample in self.downs:
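This hunk ends before `cond_prob_mask` is consumed, so how the Unet actually applies it is not visible here. One plausible pattern, sketched with hypothetical names (`null_embed` and the rest are assumptions, not the repository's API), is to swap dropped samples' conditioning for a learned null embedding:

    import torch
    from einops import rearrange

    batch, dim = 4, 512
    cond_prob_mask = torch.tensor([True, True, False, True])  # sample 2 dropped
    image_embed    = torch.randn(batch, dim)                  # conditioning signal
    null_embed     = torch.zeros(dim)                         # an nn.Parameter in practice

    keep = rearrange(cond_prob_mask, 'b -> b 1')              # (b,) -> (b, 1)
    cond = torch.where(keep, image_embed, null_embed)         # broadcasts over dim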