diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index f486caf..657db7e 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -93,6 +93,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz text_embed, predicted_image_embeddings).cpu().numpy() unrelated_similarity = cos( text_embed, predicted_unrelated_embeddings).cpu().numpy() + predicted_img_similarity = cos( + test_image_embeddings, predicted_image_embeddings).cpu().numpy() wandb.log( {"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)}) @@ -100,6 +102,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz predicted_similarity)}) wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean( unrelated_similarity)}) + wandb.log({"CosineSimilarity(image_embed,predicted_image_embed)": np.mean( + predicted_img_similarity)}) return np.mean(predicted_similarity - original_similarity)