Compare commits

...

3 Commits
0.2.1 ... 0.2.2

Author SHA1 Message Date
z
cb07b37970 Ensure Eval Mode In Metric Functions (#79)
* add eval/train toggles

* train/eval flags

* shift train toggle

Co-authored-by: nousr <z@localhost.com>
2022-05-09 16:05:40 -07:00
Phil Wang
a774bfefe2 add attention and feedforward dropouts to train_diffusion_prior script 2022-05-09 13:57:15 -07:00
Phil Wang
2ae57f0cf5 cleanup 2022-05-09 13:51:26 -07:00

View File

@@ -42,6 +42,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
wandb.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
@@ -119,6 +120,7 @@ def train(image_embed_dim,
learning_rate=0.001,
max_grad_norm=0.5,
weight_decay=0.01,
dropout=0.05,
amp=False):
# DiffusionPriorNetwork
@@ -127,6 +129,8 @@ def train(image_embed_dim,
depth = dpn_depth,
dim_head = dpn_dim_head,
heads = dpn_heads,
attn_dropout = dropout,
ff_dropout = dropout,
normformer = dp_normformer).to(device)
# DiffusionPrior with text embeddings and image embeddings pre-computed
@@ -167,10 +171,12 @@ def train(image_embed_dim,
eval_start = train_set_size
for _ in range(epochs):
diffusion_prior.train()
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
diffusion_prior.train()
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)
@@ -244,6 +250,7 @@ def main():
# Hyperparameters
parser.add_argument("--learning-rate", type=float, default=1.1e-4)
parser.add_argument("--weight-decay", type=float, default=6.02e-2)
parser.add_argument("--dropout", type=float, default=5e-2)
parser.add_argument("--max-grad-norm", type=float, default=0.5)
parser.add_argument("--batch-size", type=int, default=10**4)
parser.add_argument("--num-epochs", type=int, default=5)
@@ -261,7 +268,6 @@ def main():
# DiffusionPrior(dp) parameters
parser.add_argument("--dp-condition-on-text-encodings", type=bool, default=False)
parser.add_argument("--dp-timesteps", type=int, default=100)
parser.add_argument("--dp-l2norm-output", type=bool, default=False)
parser.add_argument("--dp-normformer", type=bool, default=False)
parser.add_argument("--dp-cond-drop-prob", type=float, default=0.1)
parser.add_argument("--dp-loss-type", type=str, default="l2")
@@ -340,6 +346,7 @@ def main():
args.learning_rate,
args.max_grad_norm,
args.weight_decay,
args.dropout,
args.amp)
if __name__ == "__main__":