Compare commits

...

3 Commits

Author SHA1 Message Date
Phil Wang
11469dc0c6 makes more sense to keep this as True as default, for stability 2022-05-02 10:50:55 -07:00
Romain Beaumont
2d25c89f35 Fix passing of l2norm_output to DiffusionPriorNetwork (#51) 2022-05-02 10:48:16 -07:00
Phil Wang
3fe96c208a add ability to train diffusion prior with l2norm on output image embed 2022-05-02 09:53:20 -07:00
3 changed files with 7 additions and 3 deletions

View File

@@ -599,7 +599,7 @@ class CausalTransformer(nn.Module):
 dim_head = 64,
 heads = 8,
 ff_mult = 4,
-norm_out = False,
+norm_out = True,
 attn_dropout = 0.,
 ff_dropout = 0.,
 final_proj = True

View File

@@ -10,7 +10,7 @@ setup(
 'dream = dalle2_pytorch.cli:dream'
 ],
 },
-version = '0.0.91',
+version = '0.0.92',
 license='MIT',
 description = 'DALL-E 2',
 author = 'Phil Wang',

View File

@@ -53,6 +53,7 @@ def train(image_embed_dim,
 clip,
 dp_condition_on_text_encodings,
 dp_timesteps,
+dp_l2norm_output,
 dp_cond_drop_prob,
 dpn_depth,
 dpn_dim_head,
@@ -70,7 +71,8 @@ def train(image_embed_dim,
 dim = image_embed_dim,
 depth = dpn_depth,
 dim_head = dpn_dim_head,
-heads = dpn_heads).to(device)
+heads = dpn_heads,
+l2norm_output = dp_l2norm_output).to(device)
 # DiffusionPrior with text embeddings and image embeddings pre-computed
 diffusion_prior = DiffusionPrior(
@@ -180,6 +182,7 @@ def main():
 # DiffusionPrior(dp) parameters
 parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
 parser.add_argument("--dp-timesteps", type=int, default=100)
+parser.add_argument("--dp-l2norm-output", type=bool, default=False)
 parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
 parser.add_argument("--dp-loss-type", type=str, default="l2")
 parser.add_argument("--clip", type=str, default=None)
@@ -223,6 +226,7 @@ def main():
 args.clip,
 args.dp_condition_on_text_encodings,
 args.dp_timesteps,
+args.dp_l2norm_output,
 args.dp_cond_drop_prob,
 args.dpn_depth,
 args.dpn_dim_head,