diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index 1d541f1..660e12b 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -54,6 +54,7 @@ def train(image_embed_dim,
           dp_condition_on_text_encodings,
           dp_timesteps,
           dp_l2norm_output,
+          dp_normformer,
           dp_cond_drop_prob,
           dpn_depth,
           dpn_dim_head,
@@ -72,6 +73,7 @@ def train(image_embed_dim,
             depth = dpn_depth,
             dim_head = dpn_dim_head,
             heads = dpn_heads,
+            normformer = dp_normformer,
             l2norm_output = dp_l2norm_output).to(device)
 
     # DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -183,6 +185,7 @@ def main():
     parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
     parser.add_argument("--dp-timesteps", type=int, default=100)
     parser.add_argument("--dp-l2norm-output", type=bool, default=False)
+    parser.add_argument("--dp-normformer", type=bool, default=False)
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
@@ -227,6 +230,7 @@ def main():
          args.dp_condition_on_text_encodings,
          args.dp_timesteps,
          args.dp_l2norm_output,
+         args.dp_normformer,
          args.dp_cond_drop_prob,
          args.dpn_depth,
          args.dpn_dim_head,
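
For reference, a minimal sketch (not part of the diff) of how the new dp_normformer flag ends up at the prior network, using only the keyword arguments the script already passes; the numeric hyperparameter values below are illustrative, not taken from this change:

# Sketch only: mirrors the DiffusionPriorNetwork construction in
# train_diffusion_prior.py after this patch. Values for dim/depth/
# dim_head/heads are placeholders, not defaults from the script.
import torch
from dalle2_pytorch import DiffusionPriorNetwork

device = "cuda" if torch.cuda.is_available() else "cpu"

prior_network = DiffusionPriorNetwork(
    dim = 768,               # image_embed_dim (illustrative)
    depth = 6,               # dpn_depth (illustrative)
    dim_head = 64,           # dpn_dim_head (illustrative)
    heads = 8,               # dpn_heads (illustrative)
    normformer = True,       # new flag threaded through as dp_normformer
    l2norm_output = False,   # dp_l2norm_output, already passed by the script
).to(device)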