mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
add attention and feedforward dropouts to train_diffusion_prior script
This commit is contained in:
@@ -119,6 +119,7 @@ def train(image_embed_dim,
|
|||||||
learning_rate=0.001,
|
learning_rate=0.001,
|
||||||
max_grad_norm=0.5,
|
max_grad_norm=0.5,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
|
dropout=0.05,
|
||||||
amp=False):
|
amp=False):
|
||||||
|
|
||||||
# DiffusionPriorNetwork
|
# DiffusionPriorNetwork
|
||||||
@@ -127,6 +128,8 @@ def train(image_embed_dim,
|
|||||||
depth = dpn_depth,
|
depth = dpn_depth,
|
||||||
dim_head = dpn_dim_head,
|
dim_head = dpn_dim_head,
|
||||||
heads = dpn_heads,
|
heads = dpn_heads,
|
||||||
|
attn_dropout = dropout,
|
||||||
|
ff_dropout = dropout,
|
||||||
normformer = dp_normformer).to(device)
|
normformer = dp_normformer).to(device)
|
||||||
|
|
||||||
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
# DiffusionPrior with text embeddings and image embeddings pre-computed
|
||||||
@@ -244,6 +247,7 @@ def main():
|
|||||||
# Hyperparameters
|
# Hyperparameters
|
||||||
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
|
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
|
||||||
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
|
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
|
||||||
|
parser.add_argument("--dropout", type=float, default=5e-2)
|
||||||
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
parser.add_argument("--max-grad-norm", type=float, default=0.5)
|
||||||
parser.add_argument("--batch-size", type=int, default=10**4)
|
parser.add_argument("--batch-size", type=int, default=10**4)
|
||||||
parser.add_argument("--num-epochs", type=int, default=5)
|
parser.add_argument("--num-epochs", type=int, default=5)
|
||||||
@@ -339,6 +343,7 @@ def main():
|
|||||||
args.learning_rate,
|
args.learning_rate,
|
||||||
args.max_grad_norm,
|
args.max_grad_norm,
|
||||||
args.weight_decay,
|
args.weight_decay,
|
||||||
|
args.dropout,
|
||||||
args.amp)
|
args.amp)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user