diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index fad469a..fb6770d 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -169,6 +169,7 @@ class DiffusionPriorNetwork(nn.Module): text_encodings, text_embed, mask = None, + cond_drop_prob = 0.2 ): batch = image_embed.shape[0]