From 2d25c89f358877c434c9c09de4380ceaa91f53d4 Mon Sep 17 00:00:00 2001 From: Romain Beaumont Date: Mon, 2 May 2022 19:48:16 +0200 Subject: [PATCH] Fix passing of l2norm_output to DiffusionPriorNetwork (#51) --- train_diffusion_prior.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index e0327e7..1d541f1 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -71,7 +71,8 @@ def train(image_embed_dim, dim = image_embed_dim, depth = dpn_depth, dim_head = dpn_dim_head, - heads = dpn_heads).to(device) + heads = dpn_heads, + l2norm_output = dp_l2norm_output).to(device) # DiffusionPrior with text embeddings and image embeddings pre-computed diffusion_prior = DiffusionPrior( @@ -79,7 +80,6 @@ def train(image_embed_dim, clip = clip, image_embed_dim = image_embed_dim, timesteps = dp_timesteps, - l2norm_output = dp_l2norm_output, cond_drop_prob = dp_cond_drop_prob, loss_type = dp_loss_type, condition_on_text_encodings = dp_condition_on_text_encodings).to(device)