mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
remove l2norm output from train_diffusion_prior.py
This commit is contained in:
@@ -85,7 +85,6 @@ def train(image_embed_dim,
|
|||||||
clip,
|
clip,
|
||||||
dp_condition_on_text_encodings,
|
dp_condition_on_text_encodings,
|
||||||
dp_timesteps,
|
dp_timesteps,
|
||||||
dp_l2norm_output,
|
|
||||||
dp_normformer,
|
dp_normformer,
|
||||||
dp_cond_drop_prob,
|
dp_cond_drop_prob,
|
||||||
dpn_depth,
|
dpn_depth,
|
||||||
@@ -105,8 +104,7 @@ def train(image_embed_dim,
|
|||||||
depth = dpn_depth,
|
depth = dpn_depth,
|
||||||
dim_head = dpn_dim_head,
|
dim_head = dpn_dim_head,
|
||||||
heads = dpn_heads,
|
heads = dpn_heads,
|
||||||
normformer = dp_normformer,
|
normformer = dp_normformer).to(device)
|
||||||
l2norm_output = dp_l2norm_output).to(device)
|
|
||||||
|
|
||||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||||
diffusion_prior = DiffusionPrior(
|
diffusion_prior = DiffusionPrior(
|
||||||
@@ -273,7 +271,6 @@ def main():
|
|||||||
args.clip,
|
args.clip,
|
||||||
args.dp_condition_on_text_encodings,
|
args.dp_condition_on_text_encodings,
|
||||||
args.dp_timesteps,
|
args.dp_timesteps,
|
||||||
args.dp_l2norm_output,
|
|
||||||
args.dp_normformer,
|
args.dp_normformer,
|
||||||
args.dp_cond_drop_prob,
|
args.dp_cond_drop_prob,
|
||||||
args.dpn_depth,
|
args.dpn_depth,
|
||||||
|
|||||||
Reference in New Issue
Block a user