mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Fix passing of l2norm_output to DiffusionPriorNetwork (#51)
This commit is contained in:
@@ -71,7 +71,8 @@ def train(image_embed_dim,
|
||||
dim = image_embed_dim,
|
||||
depth = dpn_depth,
|
||||
dim_head = dpn_dim_head,
|
||||
heads = dpn_heads).to(device)
|
||||
heads = dpn_heads,
|
||||
l2norm_output = dp_l2norm_output).to(device)
|
||||
|
||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||
diffusion_prior = DiffusionPrior(
|
||||
@@ -79,7 +80,6 @@ def train(image_embed_dim,
|
||||
clip = clip,
|
||||
image_embed_dim = image_embed_dim,
|
||||
timesteps = dp_timesteps,
|
||||
l2norm_output = dp_l2norm_output,
|
||||
cond_drop_prob = dp_cond_drop_prob,
|
||||
loss_type = dp_loss_type,
|
||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
||||
|
||||
Reference in New Issue
Block a user