allow laion to experiment with normformer in diffusion prior

This commit is contained in:
Phil Wang
2022-05-02 11:35:00 -07:00
parent 70282de23b
commit aa8d135245

View File

@@ -54,6 +54,7 @@ def train(image_embed_dim,
dp_condition_on_text_encodings, dp_condition_on_text_encodings,
dp_timesteps, dp_timesteps,
dp_l2norm_output, dp_l2norm_output,
dp_normformer,
dp_cond_drop_prob, dp_cond_drop_prob,
dpn_depth, dpn_depth,
dpn_dim_head, dpn_dim_head,
@@ -72,6 +73,7 @@ def train(image_embed_dim,
depth = dpn_depth, depth = dpn_depth,
dim_head = dpn_dim_head, dim_head = dpn_dim_head,
heads = dpn_heads, heads = dpn_heads,
normformer = dp_normformer,
l2norm_output = dp_l2norm_output).to(device) l2norm_output = dp_l2norm_output).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed # DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -183,6 +185,7 @@ def main():
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False) parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100) parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False) parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2) parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
parser.add_argument("--dp-loss-type", type=str, default="l2") parser.add_argument("--dp-loss-type", type=str, default="l2")
parser.add_argument("--clip", type=str, default=None) parser.add_argument("--clip", type=str, default=None)
@@ -227,6 +230,7 @@ def main():
args.dp_condition_on_text_encodings, args.dp_condition_on_text_encodings,
args.dp_timesteps, args.dp_timesteps,
args.dp_l2norm_output, args.dp_l2norm_output,
args.dp_normformer,
args.dp_cond_drop_prob, args.dp_cond_drop_prob,
args.dpn_depth, args.dpn_depth,
args.dpn_dim_head, args.dpn_dim_head,