From cd5f2c1de41eb8ed7cf3c12bc50cdf4fd001e22e Mon Sep 17 00:00:00 2001
From: z <51308183+nousr@users.noreply.github.com>
Date: Sat, 7 May 2022 05:34:59 -0700
Subject: [PATCH] simulate unrelated captions as a training metric (#66)

* add unrelated embedding metric

* change to torch.roll

Co-authored-by: nousr
Co-authored-by: nousr <>
---
 train_diffusion_prior.py | 54 ++++++++++++++++++++++++++++++++--------
 1 file changed, 43 insertions(+), 11 deletions(-)

diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index a607209..f486caf 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -46,28 +46,60 @@ def save_model(save_path, state_dict):
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
 
-def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
+
+def report_cosine_sims(diffusion_prior, image_reader, text_reader, train_set_size, val_set_size, NUM_TEST_EMBEDDINGS, device):
     cos = nn.CosineSimilarity(dim=1, eps=1e-6)
     tstart = train_set_size+val_set_size
     tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
-    for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
+    for embt, embi in zip(text_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend), image_reader(batch_size=NUM_TEST_EMBEDDINGS, start=tstart, end=tend)):
+        # make a copy of the text embeddings for shuffling
         text_embed = torch.tensor(embt[0]).to(device)
-        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
-        test_text_cond = dict(text_embed = text_embed)
+        text_embed_shuffled = text_embed.clone()
+
+        # roll the text embeddings to simulate "unrelated" captions
+        rolled_idx = torch.roll(torch.arange(NUM_TEST_EMBEDDINGS), 1)
+        text_embed_shuffled = text_embed_shuffled[rolled_idx]
+        text_embed_shuffled = text_embed_shuffled / \
+            text_embed_shuffled.norm(dim=1, keepdim=True)
+        test_text_shuffled_cond = dict(text_embed=text_embed_shuffled)
+
+        # prepare the text embedding
+        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
+        test_text_cond = dict(text_embed=text_embed)
+
+        # prepare image embeddings
         test_image_embeddings = torch.tensor(embi[0]).to(device)
-        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
+        test_image_embeddings = test_image_embeddings / \
+            test_image_embeddings.norm(dim=1, keepdim=True)
 
-        predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
-        predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
+        # predict on the unshuffled text embeddings
+        predicted_image_embeddings = diffusion_prior.p_sample_loop(
+            (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_cond)
+        predicted_image_embeddings = predicted_image_embeddings / \
+            predicted_image_embeddings.norm(dim=1, keepdim=True)
 
-        original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
-        predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
+        # predict on the shuffled embeddings
+        predicted_unrelated_embeddings = diffusion_prior.p_sample_loop(
+            (NUM_TEST_EMBEDDINGS, 768), text_cond=test_text_shuffled_cond)
+        predicted_unrelated_embeddings = predicted_unrelated_embeddings / \
+            predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
 
-        wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
-        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
+        # calculate similarities
+        original_similarity = cos(
+            text_embed, test_image_embeddings).cpu().numpy()
+        predicted_similarity = cos(
+            text_embed, predicted_image_embeddings).cpu().numpy()
+        unrelated_similarity = cos(
+            text_embed, predicted_unrelated_embeddings).cpu().numpy()
+
+        wandb.log(
+            {"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)": np.mean(
+            predicted_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_unrelated_embed)": np.mean(
+            unrelated_similarity)})
 
     return np.mean(predicted_similarity - original_similarity)
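
A minimal, self-contained sketch of the torch.roll trick the patch uses to fake "unrelated" captions. The toy batch size and random embeddings below are assumptions for illustration only; the training script pairs real 768-dim CLIP text/image embeddings read by text_reader/image_reader.

    import torch

    # assumed toy values; the patch operates on NUM_TEST_EMBEDDINGS embeddings of dim 768
    num_test_embeddings = 4
    text_embed = torch.randn(num_test_embeddings, 768)

    # cyclic shift of the batch indices -> tensor([3, 0, 1, 2]) for a batch of 4,
    # so every caption is re-paired with a different sample in the batch
    rolled_idx = torch.roll(torch.arange(num_test_embeddings), 1)
    text_embed_shuffled = text_embed[rolled_idx]

    # no row keeps its original partner as long as the batch has more than one element
    assert not torch.equal(text_embed_shuffled, text_embed)

The prior's prediction from a rolled (mismatched) caption should score a lower cosine similarity against the original text embedding than the prediction from the matching caption, so the logged "predicted_unrelated_embed" curve acts as a baseline floor for the "predicted_image_embed" metric.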