mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-21 10:44:18 +01:00
add ability to train diffusion prior with l2norm on output image embed
This commit is contained in:
@@ -53,6 +53,7 @@ def train(image_embed_dim,
|
|||||||
clip,
|
clip,
|
||||||
dp_condition_on_text_encodings,
|
dp_condition_on_text_encodings,
|
||||||
dp_timesteps,
|
dp_timesteps,
|
||||||
|
dp_l2norm_output,
|
||||||
dp_cond_drop_prob,
|
dp_cond_drop_prob,
|
||||||
dpn_depth,
|
dpn_depth,
|
||||||
dpn_dim_head,
|
dpn_dim_head,
|
||||||
@@ -78,6 +79,7 @@ def train(image_embed_dim,
|
|||||||
clip = clip,
|
clip = clip,
|
||||||
image_embed_dim = image_embed_dim,
|
image_embed_dim = image_embed_dim,
|
||||||
timesteps = dp_timesteps,
|
timesteps = dp_timesteps,
|
||||||
|
l2norm_output = dp_l2norm_output,
|
||||||
cond_drop_prob = dp_cond_drop_prob,
|
cond_drop_prob = dp_cond_drop_prob,
|
||||||
loss_type = dp_loss_type,
|
loss_type = dp_loss_type,
|
||||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
||||||
@@ -180,6 +182,7 @@ def main():
|
|||||||
# DiffusionPrior(dp) parameters
|
# DiffusionPrior(dp) parameters
|
||||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||||
|
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
|
||||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
||||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||||
parser.add_argument("--clip", type=str, default=None)
|
parser.add_argument("--clip", type=str, default=None)
|
||||||
@@ -223,6 +226,7 @@ def main():
|
|||||||
args.clip,
|
args.clip,
|
||||||
args.dp_condition_on_text_encodings,
|
args.dp_condition_on_text_encodings,
|
||||||
args.dp_timesteps,
|
args.dp_timesteps,
|
||||||
|
args.dp_l2norm_output,
|
||||||
args.dp_cond_drop_prob,
|
args.dp_cond_drop_prob,
|
||||||
args.dpn_depth,
|
args.dpn_depth,
|
||||||
args.dpn_dim_head,
|
args.dpn_dim_head,
|
||||||
|
|||||||
Reference in New Issue
Block a user