mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Fix passing of l2norm_output to DiffusionPriorNetwork (#51)
This commit is contained in:
@@ -71,7 +71,8 @@ def train(image_embed_dim,
|
|||||||
dim = image_embed_dim,
|
dim = image_embed_dim,
|
||||||
depth = dpn_depth,
|
depth = dpn_depth,
|
||||||
dim_head = dpn_dim_head,
|
dim_head = dpn_dim_head,
|
||||||
heads = dpn_heads).to(device)
|
heads = dpn_heads,
|
||||||
|
l2norm_output = dp_l2norm_output).to(device)
|
||||||
|
|
||||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||||
diffusion_prior = DiffusionPrior(
|
diffusion_prior = DiffusionPrior(
|
||||||
@@ -79,7 +80,6 @@ def train(image_embed_dim,
|
|||||||
clip = clip,
|
clip = clip,
|
||||||
image_embed_dim = image_embed_dim,
|
image_embed_dim = image_embed_dim,
|
||||||
timesteps = dp_timesteps,
|
timesteps = dp_timesteps,
|
||||||
l2norm_output = dp_l2norm_output,
|
|
||||||
cond_drop_prob = dp_cond_drop_prob,
|
cond_drop_prob = dp_cond_drop_prob,
|
||||||
loss_type = dp_loss_type,
|
loss_type = dp_loss_type,
|
||||||
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
condition_on_text_encodings = dp_condition_on_text_encodings).to(device)
|
||||||
|
|||||||
Reference in New Issue
Block a user