diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index e0327e7..1d541f1 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -71,7 +71,8 @@ def train(image_embed_dim, dim = image_embed_dim, depth = dpn_depth, dim_head = dpn_dim_head, - heads = dpn_heads).to(device) + heads = dpn_heads, + l2norm_output = dp_l2norm_output).to(device) # DiffusionPrior with text embeddings and image embeddings pre-computed diffusion_prior = DiffusionPrior( @@ -79,7 +80,6 @@ def train(image_embed_dim, clip = clip, image_embed_dim = image_embed_dim, timesteps = dp_timesteps, - l2norm_output = dp_l2norm_output, cond_drop_prob = dp_cond_drop_prob, loss_type = dp_loss_type, condition_on_text_encodings = dp_condition_on_text_encodings).to(device)