mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-21 02:34:19 +01:00
complete conditional dropout mask creation for both prior network as well as image decoder unet for classifier free guidance
This commit is contained in:
@@ -42,6 +42,16 @@ def freeze_model_and_make_eval_(model):
|
|||||||
model.eval()
|
model.eval()
|
||||||
freeze_all_layers_(model)
|
freeze_all_layers_(model)
|
||||||
|
|
||||||
|
# classifier free guidance functions
|
||||||
|
|
||||||
|
def prob_mask_like(shape, prob, device):
|
||||||
|
if prob == 1:
|
||||||
|
return torch.ones(shape, device = device, dtype = torch.bool)
|
||||||
|
elif prob == 0:
|
||||||
|
return torch.zeros(shape, device = device, dtype = torch.bool)
|
||||||
|
else:
|
||||||
|
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
|
||||||
|
|
||||||
# diffusion prior
|
# diffusion prior
|
||||||
|
|
||||||
class RMSNorm(nn.Module):
|
class RMSNorm(nn.Module):
|
||||||
@@ -179,7 +189,7 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
mask = None,
|
mask = None,
|
||||||
cond_drop_prob = 0.2
|
cond_drop_prob = 0.2
|
||||||
):
|
):
|
||||||
batch = image_embed.shape[0]
|
batch, text_enc_len, device = image_embed.shape[0], text_encodings.shape[-2], image_embed.device
|
||||||
|
|
||||||
# in section 2.2, last paragraph
|
# in section 2.2, last paragraph
|
||||||
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
# "... consisting of encoded text, CLIP text embedding, diffusion timestep embedding, noised CLIP image embedding, final embedding for prediction"
|
||||||
@@ -200,6 +210,18 @@ class DiffusionPriorNetwork(nn.Module):
|
|||||||
learned_queries
|
learned_queries
|
||||||
), dim = -2)
|
), dim = -2)
|
||||||
|
|
||||||
|
# mask if it doesn't exist
|
||||||
|
|
||||||
|
if not exists(mask):
|
||||||
|
mask = torch.ones((batch, text_enc_len), device = device, dtype = torch.bool)
|
||||||
|
|
||||||
|
# classifier free guidance
|
||||||
|
|
||||||
|
cond_prob_mask = prob_mask_like(batch_size, cond_prob_drop, device = device)
|
||||||
|
mask &= rearrange(cond_prob_mask, 'b -> b 1')
|
||||||
|
|
||||||
|
# attend
|
||||||
|
|
||||||
tokens = self.causal_transformer(tokens, mask = mask)
|
tokens = self.causal_transformer(tokens, mask = mask)
|
||||||
|
|
||||||
# get learned query, which should predict the image embedding (per DDPM timestep)
|
# get learned query, which should predict the image embedding (per DDPM timestep)
|
||||||
@@ -358,8 +380,11 @@ class Unet(nn.Module):
|
|||||||
text_encodings = None,
|
text_encodings = None,
|
||||||
cond_prob_drop = 0.2
|
cond_prob_drop = 0.2
|
||||||
):
|
):
|
||||||
|
batch_size, device = image_embed.shape[0], image_embed.device
|
||||||
t = self.time_mlp(time) if exists(self.time_mlp) else None
|
t = self.time_mlp(time) if exists(self.time_mlp) else None
|
||||||
|
|
||||||
|
cond_prob_mask = prob_mask_like(batch_size, cond_prob_drop, device = device)
|
||||||
|
|
||||||
hiddens = []
|
hiddens = []
|
||||||
|
|
||||||
for convnext, convnext2, downsample in self.downs:
|
for convnext, convnext2, downsample in self.downs:
|
||||||
|
|||||||
Reference in New Issue
Block a user