Mirror of https://github.com/lucidrains/DALLE2-pytorch.git, synced 2026-02-19 12:24:39 +01:00.
Compare commits
3 Commits
| Author | SHA1 | Date |
|---|---|---|
|
|
11469dc0c6 | ||
|
|
2d25c89f35 | ||
|
|
3fe96c208a |
@@ -599,7 +599,7 @@ class CausalTransformer(nn.Module):
|
|||||||
dim_head = 64,
|
dim_head = 64,
|
||||||
heads = 8,
|
heads = 8,
|
||||||
ff_mult = 4,
|
ff_mult = 4,
|
||||||
norm_out = False,
|
norm_out = True,
|
||||||
attn_dropout = 0.,
|
attn_dropout = 0.,
|
||||||
ff_dropout = 0.,
|
ff_dropout = 0.,
|
||||||
final_proj = True
|
final_proj = True
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -10,7 +10,7 @@ setup(
|
|||||||
'dream = dalle2_pytorch.cli:dream'
|
'dream = dalle2_pytorch.cli:dream'
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
version = '0.0.91',
|
version = '0.0.92',
|
||||||
license='MIT',
|
license='MIT',
|
||||||
description = 'DALL-E 2',
|
description = 'DALL-E 2',
|
||||||
author = 'Phil Wang',
|
author = 'Phil Wang',
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ def train(image_embed_dim,
|
|||||||
clip,
|
clip,
|
||||||
dp_condition_on_text_encodings,
|
dp_condition_on_text_encodings,
|
||||||
dp_timesteps,
|
dp_timesteps,
|
||||||
|
dp_l2norm_output,
|
||||||
dp_cond_drop_prob,
|
dp_cond_drop_prob,
|
||||||
dpn_depth,
|
dpn_depth,
|
||||||
dpn_dim_head,
|
dpn_dim_head,
|
||||||
@@ -70,7 +71,8 @@ def train(image_embed_dim,
|
|||||||
dim = image_embed_dim,
|
dim = image_embed_dim,
|
||||||
depth = dpn_depth,
|
depth = dpn_depth,
|
||||||
dim_head = dpn_dim_head,
|
dim_head = dpn_dim_head,
|
||||||
heads = dpn_heads).to(device)
|
heads = dpn_heads,
|
||||||
|
l2norm_output = dp_l2norm_output).to(device)
|
||||||
|
|
||||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||||
diffusion_prior = DiffusionPrior(
|
diffusion_prior = DiffusionPrior(
|
||||||
@@ -180,6 +182,7 @@ def main():
|
|||||||
# DiffusionPrior(dp) parameters
|
# DiffusionPrior(dp) parameters
|
||||||
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
|
||||||
parser.add_argument("--dp-timesteps", type=int, default=100)
|
parser.add_argument("--dp-timesteps", type=int, default=100)
|
||||||
|
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
|
||||||
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.2)
|
||||||
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
parser.add_argument("--dp-loss-type", type=str, default="l2")
|
||||||
parser.add_argument("--clip", type=str, default=None)
|
parser.add_argument("--clip", type=str, default=None)
|
||||||
@@ -223,6 +226,7 @@ def main():
|
|||||||
args.clip,
|
args.clip,
|
||||||
args.dp_condition_on_text_encodings,
|
args.dp_condition_on_text_encodings,
|
||||||
args.dp_timesteps,
|
args.dp_timesteps,
|
||||||
|
args.dp_l2norm_output,
|
||||||
args.dp_cond_drop_prob,
|
args.dp_cond_drop_prob,
|
||||||
args.dpn_depth,
|
args.dpn_depth,
|
||||||
args.dpn_dim_head,
|
args.dpn_dim_head,
|
||||||
|
|||||||
Reference in New Issue
Block a user