add ability to train diffusion prior with l2norm on output image embed

This commit is contained in:
Phil Wang
2022-05-02 09:53:20 -07:00
parent 0fc6c9cdf3
commit 3fe96c208a

View File

@@ -53,6 +53,7 @@ def train(image_embed_dim,
clip, clip,
dp_condition_on_text_encodings, dp_condition_on_text_encodings,
dp_timesteps, dp_timesteps,
dp_l2norm_output,
dp_cond_drop_prob, dp_cond_drop_prob,
dpn_depth, dpn_depth,
dpn_dim_head, dpn_dim_head,
@@ -78,6 +79,7 @@ def train(image_embed_dim,
clip = clip, clip = clip,
image_embed_dim = image_embed_dim, image_embed_dim = image_embed_dim,
timesteps = dp_timesteps, timesteps = dp_timesteps,
l2norm_output = dp_l2norm_output,
cond_drop_prob = dp_cond_drop_prob, cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type, loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings).to(device) condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
@@ -180,6 +182,7 @@ def main():
# DiffusionPrior(dp) parameters # DiffusionPrior(dp) parameters
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False) parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100) parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2) parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
parser.add_argument("--dp-loss-type", type=str, default="l2") parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None) parser.add_argument("--clip", type=str, default=None)
@@ -223,6 +226,7 @@ def main():
args.clip, args.clip,
args.dp_condition_on_text_encodings, args.dp_condition_on_text_encodings,
args.dp_timesteps, args.dp_timesteps,
args.dp_l2norm_output,
args.dp_cond_drop_prob, args.dp_cond_drop_prob,
args.dpn_depth, args.dpn_depth,
args.dpn_dim_head, args.dpn_dim_head,