From a774bfefe27ceb6171ad486165739014cf1ee9f0 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 9 May 2022 13:57:15 -0700
Subject: [PATCH] add attention and feedforward dropouts to
 train_diffusion_prior script

---
 train_diffusion_prior.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index af79ec0..7e1f10b 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -119,6 +119,7 @@ def train(image_embed_dim,
           learning_rate=0.001,
           max_grad_norm=0.5,
           weight_decay=0.01,
+          dropout=0.05,
           amp=False):
 
     # DiffusionPriorNetwork
@@ -127,6 +128,8 @@ def train(image_embed_dim,
         depth = dpn_depth,
         dim_head = dpn_dim_head,
         heads = dpn_heads,
+        attn_dropout = dropout,
+        ff_dropout = dropout,
         normformer = dp_normformer).to(device)
 
     # DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -244,6 +247,7 @@ def main():
     # Hyperparameters
     parser.add_argument("--learning-rate", type=float, default=1.1e-4)
     parser.add_argument("--weight-decay", type=float, default=6.02e-2)
+    parser.add_argument("--dropout", type=float, default=5e-2)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
     parser.add_argument("--batch-size", type=int, default=10**4)
     parser.add_argument("--num-epochs", type=int, default=5)
@@ -339,6 +343,7 @@ def main():
           args.learning_rate,
           args.max_grad_norm,
           args.weight_decay,
+          args.dropout,
           args.amp)
 
 if __name__ == "__main__":
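
For context, this is roughly how the single new dropout value fans out to the two DiffusionPriorNetwork dropout keyword arguments at the call site. A minimal sketch, assuming the dalle2_pytorch import path; the dim/depth/head values are illustrative placeholders, not taken from the patch:

# Minimal sketch: one dropout value mirrored to both the attention and
# feedforward dropouts of DiffusionPriorNetwork, as this patch wires it.
# Import path and the dim/depth/head values below are assumptions for
# illustration, not part of the patch.
from dalle2_pytorch import DiffusionPriorNetwork

dropout = 0.05  # the single value exposed via the new --dropout flag

prior_network = DiffusionPriorNetwork(
    dim = 768,               # hypothetical image_embed_dim
    depth = 6,               # hypothetical dpn_depth
    dim_head = 64,           # hypothetical dpn_dim_head
    heads = 8,               # hypothetical dpn_heads
    attn_dropout = dropout,  # dropout inside the attention blocks
    ff_dropout = dropout,    # dropout inside the feedforward blocks
)

With the patch applied, the value can be set from the command line, e.g. `python train_diffusion_prior.py --dropout 0.1` alongside whatever other flags a given run needs; the CLI default of 5e-2 matches the keyword default of 0.05 in train().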