Ensure Eval Mode In Metric Functions (#79)

* add eval/train toggles

* train/eval flags

* shift train toggle

Co-authored-by: nousr <z@localhost.com>
This commit is contained in:
z
2022-05-09 16:05:40 -07:00
committed by GitHub
parent a774bfefe2
commit cb07b37970

View File

@@ -42,6 +42,7 @@ def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_t
wandb.log({f'{phase} {loss_type}': avg_loss})
def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,NUM_TEST_EMBEDDINGS,device):
diffusion_prior.eval()
cos = nn.CosineSimilarity(dim=1, eps=1e-6)
@@ -170,10 +171,12 @@ def train(image_embed_dim,
eval_start = train_set_size
for _ in range(epochs):
diffusion_prior.train()
for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=0, end=train_set_size),
text_reader(batch_size=batch_size, start=0, end=train_set_size)):
diffusion_prior.train()
emb_images_tensor = torch.tensor(emb_images[0]).to(device)
emb_text_tensor = torch.tensor(emb_text[0]).to(device)