From 81d83dd7f209ca54d158cbc6d169006f6b240d08 Mon Sep 17 00:00:00 2001 From: z <51308183+nousr@users.noreply.github.com> Date: Mon, 2 May 2022 13:52:11 -0700 Subject: [PATCH] defaults align with paper (#52) Co-authored-by: nousr <> --- train_diffusion_prior.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 660e12b..5ccb6d3 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -165,8 +165,8 @@ def main(): parser.add_argument("--image-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/img_emb/") parser.add_argument("--text-embed-url", type=str, default="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion2B-en/text_emb/") # Hyperparameters - parser.add_argument("--learning-rate", type=float, default=0.001) - parser.add_argument("--weight-decay", type=float, default=0.01) + parser.add_argument("--learning-rate", type=float, default=1.1e-4) + parser.add_argument("--weight-decay", type=float, default=6.02e-2) parser.add_argument("--max-grad-norm", type=float, default=0.5) parser.add_argument("--batch-size", type=int, default=10**4) parser.add_argument("--num-epochs", type=int, default=5) @@ -186,7 +186,7 @@ def main(): parser.add_argument("--dp-timesteps", type=int, default=100) parser.add_argument("--dp-l2norm-output", type=bool, default=False) parser.add_argument("--dp-normformer", type=bool, default=False) - parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2) + parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1) parser.add_argument("--dp-loss-type", type=str, default="l2") parser.add_argument("--clip", type=str, default=None) parser.add_argument("--amp", type=bool, default=False)