Fix passing of l2norm_output to DiffusionPriorNetwork (#51)

This commit is contained in:
Romain Beaumont
2022-05-02 19:48:16 +02:00
committed by GitHub
parent 3fe96c208a
commit 2d25c89f35

View File

@@ -71,7 +71,8 @@ def train(image_embed_dim,
dim = image_embed_dim, dim = image_embed_dim,
depth = dpn_depth, depth = dpn_depth,
dim_head = dpn_dim_head, dim_head = dpn_dim_head,
heads = dpn_heads).to(device) heads = dpn_heads,
l2norm_output = dp_l2norm_output).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed # DiffusionPrior with text embeddings and image embeddings pre-computed
diffusion_prior = DiffusionPrior( diffusion_prior = DiffusionPrior(
@@ -79,7 +80,6 @@ def train(image_embed_dim,
clip = clip, clip = clip,
image_embed_dim = image_embed_dim, image_embed_dim = image_embed_dim,
timesteps = dp_timesteps, timesteps = dp_timesteps,
l2norm_output = dp_l2norm_output,
cond_drop_prob = dp_cond_drop_prob, cond_drop_prob = dp_cond_drop_prob,
loss_type = dp_loss_type, loss_type = dp_loss_type,
condition_on_text_encodings = dp_condition_on_text_encodings).to(device) condition_on_text_encodings = dp_condition_on_text_encodings).to(device)