From 2eac7996faef08803230c7f3b2f9199e07c33b39 Mon Sep 17 00:00:00 2001 From: Nasir Khalid Date: Sat, 7 May 2022 17:32:33 -0400 Subject: [PATCH] Additional image_embed metric (#75) Added metric to track image_embed vs predicted_image_embed --- train_diffusion_prior.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index f486caf..657db7e 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -93,6 +93,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz text_embed, predicted_image_embeddings).cpu().numpy() unrelated_similarity = cos( text_embed, predicted_unrelated_embeddings).cpu().numpy() + predicted_img_similarity = cos( + test_image_embeddings, predicted_image_embeddings).cpu().numpy() wandb.log( {"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)}) @@ -100,6 +102,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz predicted_similarity)}) wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean( unrelated_similarity)}) + wandb.log({"CosineSimilarity(image_embed,predicted_image_embed)": np.mean( + predicted_img_similarity)}) return np.mean(predicted_similarity - original_similarity)