mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
Additional image_embed metric (#75)
Added metric to track image_embed vs predicted_image_embed
This commit is contained in:
@@ -93,6 +93,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz
|
|||||||
text_embed, predicted_image_embeddings).cpu().numpy()
|
text_embed, predicted_image_embeddings).cpu().numpy()
|
||||||
unrelated_similarity = cos(
|
unrelated_similarity = cos(
|
||||||
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
text_embed, predicted_unrelated_embeddings).cpu().numpy()
|
||||||
|
predicted_img_similarity = cos(
|
||||||
|
test_image_embeddings, predicted_image_embeddings).cpu().numpy()
|
||||||
|
|
||||||
wandb.log(
|
wandb.log(
|
||||||
{"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
|
{"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
|
||||||
@@ -100,6 +102,8 @@ def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_siz
|
|||||||
predicted_similarity)})
|
predicted_similarity)})
|
||||||
wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
|
wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
|
||||||
unrelated_similarity)})
|
unrelated_similarity)})
|
||||||
|
wandb.log({"CosineSimilarity(image_embed,predicted_image_embed)": np.mean(
|
||||||
|
predicted_img_similarity)})
|
||||||
|
|
||||||
return np.mean(predicted_similarity - original_similarity)
|
return np.mean(predicted_similarity - original_similarity)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user