mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
allow laion to experiment with normformer in diffusion prior
This commit is contained in:
@@ -54,6 +54,7 @@ def train(image_embed_dim,
|
|||||||
dp_condition_on_text_encodings,
|
dp_condition_on_text_encodings,
|
||||||
dp_timesteps,
|
dp_timesteps,
|
||||||
dp_l2norm_output,
|
dp_l2norm_output,
|
||||||
|
dp_normformer,
|
||||||
dp_cond_drop_prob,
|
dp_cond_drop_prob,
|
||||||
dpn_depth,
|
dpn_depth,
|
||||||
dpn_dim_head,
|
dpn_dim_head,
|
||||||
@@ -72,6 +73,7 @@ def train(image_embed_dim,
|
|||||||
depth = dpn_depth,
|
depth = dpn_depth,
|
||||||
dim_head = dpn_dim_head,
|
dim_head = dpn_dim_head,
|
||||||
heads = dpn_heads,
|
heads = dpn_heads,
|
||||||
|
normformer = dp_normformer,
|
||||||
l2norm_output = dp_l2norm_output).to(device)
|
l2norm_output = dp_l2norm_output).to(device)
|
||||||
|
|
||||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||||
@@ -183,6 +185,7 @@ def main():
|
|||||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||||
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
|
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
|
||||||
|
parser.add_argument("--dp-normformer", type=bool, default=False)
|
||||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
||||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||||
parser.add_argument("--clip", type=str, default=None)
|
parser.add_argument("--clip", type=str, default=None)
|
||||||
@@ -227,6 +230,7 @@ def main():
|
|||||||
args.dp_condition_on_text_encodings,
|
args.dp_condition_on_text_encodings,
|
||||||
args.dp_timesteps,
|
args.dp_timesteps,
|
||||||
args.dp_l2norm_output,
|
args.dp_l2norm_output,
|
||||||
|
args.dp_normformer,
|
||||||
args.dp_cond_drop_prob,
|
args.dp_cond_drop_prob,
|
||||||
args.dpn_depth,
|
args.dpn_depth,
|
||||||
args.dpn_dim_head,
|
args.dpn_dim_head,
|
||||||
|
|||||||
Reference in New Issue
Block a user