diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index ce2439e..e0327e7 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -53,6 +53,7 @@ def train(image_embed_dim,
           clip,
           dp_condition_on_text_encodings,
           dp_timesteps,
+          dp_l2norm_output,
           dp_cond_drop_prob,
           dpn_depth,
           dpn_dim_head,
@@ -78,6 +79,7 @@ def train(image_embed_dim,
        clip = clip,
        image_embed_dim = image_embed_dim,
        timesteps = dp_timesteps,
+       l2norm_output = dp_l2norm_output,
        cond_drop_prob = dp_cond_drop_prob,
        loss_type = dp_loss_type,
        condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
@@ -180,6 +182,7 @@ def main():
     # DiffusionPrior(dp) parameters
     parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
     parser.add_argument("--dp-timesteps", type=int, default=100)
+    parser.add_argument("--dp-l2norm-output", type=bool, default=False)
     parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
     parser.add_argument("--dp-loss-type", type=str, default="l2")
     parser.add_argument("--clip", type=str, default=None)
@@ -223,6 +226,7 @@ def main():
           args.clip,
           args.dp_condition_on_text_encodings,
           args.dp_timesteps,
+          args.dp_l2norm_output,
           args.dp_cond_drop_prob,
           args.dpn_depth,
           args.dpn_dim_head,
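
Note: this patch only threads the flag through the CLI; it assumes DiffusionPrior accepts an l2norm_output keyword. A minimal sketch of what that flag is presumed to control (the helper name and the normalization point below are assumptions, not the library's confirmed internals): when enabled, the prior's predicted image embedding is rescaled to unit L2 norm, matching the normalization CLIP applies to its own embeddings.

    import torch
    import torch.nn.functional as F

    def maybe_l2norm(image_embed: torch.Tensor, l2norm_output: bool) -> torch.Tensor:
        # Hypothetical helper: with the flag set, rescale each predicted image
        # embedding to unit L2 norm along the feature dimension, mirroring how
        # CLIP normalizes its embeddings; otherwise return the prediction as-is.
        return F.normalize(image_embed, dim=-1) if l2norm_output else image_embed

    # Example: a batch of 4 predicted embeddings with image_embed_dim = 512.
    pred = torch.randn(4, 512)
    out = maybe_l2norm(pred, l2norm_output=True)
    print(out.norm(dim=-1))  # each row has norm ~1.0 when the flag is enabled

Usage with the new argument: python train_diffusion_prior.py --dp-l2norm-output True (note that argparse's type=bool converts any non-empty string to True, so passing "False" also enables the flag; omit the argument to keep the default of False).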