add attention and feedforward dropouts to train_diffusion_prior script

Author: Phil Wang
Date:   2022-05-09 13:57:15 -07:00
Parent: 2ae57f0cf5
Commit: a774bfefe2

train_diffusion_prior.py

@@ -119,6 +119,7 @@ def train(image_embed_dim,
           learning_rate=0.001,
           max_grad_norm=0.5,
           weight_decay=0.01,
+          dropout=0.05,
           amp=False):
     # DiffusionPriorNetwork
@@ -127,6 +128,8 @@ def train(image_embed_dim,
         depth = dpn_depth,
         dim_head = dpn_dim_head,
         heads = dpn_heads,
+        attn_dropout = dropout,
+        ff_dropout = dropout,
         normformer = dp_normformer).to(device)
     # DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -244,6 +247,7 @@ def main():
     # Hyperparameters
     parser.add_argument("--learning-rate", type=float, default=1.1e-4)
     parser.add_argument("--weight-decay", type=float, default=6.02e-2)
+    parser.add_argument("--dropout", type=float, default=5e-2)
     parser.add_argument("--max-grad-norm", type=float, default=0.5)
     parser.add_argument("--batch-size", type=int, default=10**4)
     parser.add_argument("--num-epochs", type=int, default=5)
@@ -339,6 +343,7 @@ def main():
         args.learning_rate,
         args.max_grad_norm,
         args.weight_decay,
+        args.dropout,
         args.amp)
 if __name__ == "__main__":
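
For reference, a minimal sketch of what the patched construction amounts to. The attn_dropout / ff_dropout keywords are what this commit wires through; the numeric values below are illustrative stand-ins for the image_embed_dim, dpn_depth, dpn_dim_head, dpn_heads, and dp_normformer parameters that the script actually parses from the command line:

    from dalle2_pytorch import DiffusionPriorNetwork

    dropout = 0.05  # matches the new --dropout default above

    prior_network = DiffusionPriorNetwork(
        dim = 768,               # illustrative; the script passes image_embed_dim
        depth = 6,               # illustrative; dpn_depth
        dim_head = 64,           # illustrative; dpn_dim_head
        heads = 8,               # illustrative; dpn_heads
        attn_dropout = dropout,  # new: dropout applied to the attention weights
        ff_dropout = dropout,    # new: dropout inside the feedforward blocks
        normformer = True        # illustrative; dp_normformer
    )

With the flag exposed, the regularization strength can be tuned at launch time, e.g. python train_diffusion_prior.py --dropout 0.1 (script name assumed from the commit title).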