From 2d9963d30e43f4feb68135f0a29cfa77db141b13 Mon Sep 17 00:00:00 2001
From: Kumar R
Date: Wed, 4 May 2022 20:34:36 +0530
Subject: [PATCH] Reporting metrics - Cosine similarity. (#55)

* Update train_diffusion_prior.py

* Delete train_diffusion_prior.py

* Cosine similarity logging.

* Update train_diffusion_prior.py

* Report Cosine metrics every N steps.
---
 train_diffusion_prior.py | 45 ++++++++++++++++++++++++++++++++++++++--
 1 file changed, 43 insertions(+), 2 deletions(-)

diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index ee66659..8ba58d3 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -1,21 +1,23 @@
 import os
 import math
 import argparse
+import numpy as np
 import torch
 from torch import nn
 from embedding_reader import EmbeddingReader
 from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork
 from dalle2_pytorch.optimizer import get_optimizer
-from dalle2_pytorch.optimizer import get_optimizer
 from torch.cuda.amp import autocast,GradScaler
-
 import time
 from tqdm import tqdm
 import wandb
 
 os.environ["WANDB_SILENT"] = "true"
 
+NUM_TEST_EMBEDDINGS = 100 # for cosine similarity reporting during training
+REPORT_METRICS_EVERY = 100 # for cosine similarity and other metric reporting during training
+
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
@@ -44,6 +46,33 @@ def save_model(save_path, state_dict):
     print("====================================== Saving checkpoint ======================================")
     torch.save(state_dict, save_path+'/'+str(time.time())+'_saved_model.pth')
 
+def report_cosine_sims(diffusion_prior,image_reader,text_reader,train_set_size,val_set_size,NUM_TEST_EMBEDDINGS,device):
+    cos = nn.CosineSimilarity(dim=1, eps=1e-6)
+
+    tstart = train_set_size+val_set_size
+    tend = train_set_size+val_set_size+NUM_TEST_EMBEDDINGS
+
+    for embt, embi in zip(text_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend),image_reader(batch_size = NUM_TEST_EMBEDDINGS, start=tstart, end = tend)):
+        text_embed = torch.tensor(embt[0]).to(device)
+        text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
+        test_text_cond = dict(text_embed = text_embed)
+
+        test_image_embeddings = torch.tensor(embi[0]).to(device)
+        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(dim=1, keepdim=True)
+
+        predicted_image_embeddings = diffusion_prior.p_sample_loop((NUM_TEST_EMBEDDINGS, 768), text_cond = test_text_cond)
+        predicted_image_embeddings = predicted_image_embeddings / predicted_image_embeddings.norm(dim=1, keepdim=True)
+
+        original_similarity = cos(text_embed,test_image_embeddings).cpu().numpy()
+        predicted_similarity = cos(text_embed,predicted_image_embeddings).cpu().numpy()
+
+        wandb.log({"CosineSimilarity(text_embed,image_embed)": np.mean(original_similarity)})
+        wandb.log({"CosineSimilarity(text_embed,predicted_image_embed)":np.mean(predicted_similarity)})
+
+    return np.mean(predicted_similarity - original_similarity)
+
+
 def train(image_embed_dim,
           image_embed_url,
           text_embed_url,
@@ -137,6 +166,18 @@ def train(image_embed_dim,
             wandb.log({"Training loss": loss.item(),
                        "Steps": step,
                        "Samples per second": samples_per_sec})
+            # Log cosineSim(text_embed,predicted_image_embed) - cosineSim(text_embed,image_embed)
+            # Use NUM_TEST_EMBEDDINGS samples from the test set each time
+            # Get embeddings from the most recently saved model
+            if(step % REPORT_METRICS_EVERY) == 0:
+                diff_cosine_sim = report_cosine_sims(diffusion_prior,
+                                                     image_reader,
+                                                     text_reader,
+                                                     train_set_size,
+                                                     val_set_size,
+                                                     NUM_TEST_EMBEDDINGS,
+                                                     device)
+                wandb.log({"Cosine similarity difference": diff_cosine_sim})
             scaler.unscale_(optimizer)
             nn.utils.clip_grad_norm_(diffusion_prior.parameters(), max_grad_norm)
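
For anyone skimming the diff, the metric added here is the mean cosine similarity between L2-normalized embedding batches, logged for the real text/image pairs and for the prior's text-conditioned samples. Below is a minimal self-contained sketch of that computation, under stated assumptions: `toy_prior_sample` is a hypothetical stand-in for `diffusion_prior.p_sample_loop`, and the random tensors stand in for the batches the patch pulls via `text_reader`/`image_reader`; nothing in the snippet is part of the patch itself.

    import torch
    from torch import nn

    NUM_TEST_EMBEDDINGS = 100
    EMBED_DIM = 768  # matches the hard-coded 768 in the patch

    def toy_prior_sample(text_embed):
        # Hypothetical stand-in for diffusion_prior.p_sample_loop(...):
        # perturb the text embedding so the metric has something to measure.
        return text_embed + 0.1 * torch.randn_like(text_embed)

    cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    # Random placeholders for the held-out test slice the patch reads
    # starting at train_set_size + val_set_size.
    text_embed = torch.randn(NUM_TEST_EMBEDDINGS, EMBED_DIM)
    image_embed = torch.randn(NUM_TEST_EMBEDDINGS, EMBED_DIM)

    # L2-normalize each row, as the patch does before comparing.
    text_embed = text_embed / text_embed.norm(dim=1, keepdim=True)
    image_embed = image_embed / image_embed.norm(dim=1, keepdim=True)

    predicted = toy_prior_sample(text_embed)
    predicted = predicted / predicted.norm(dim=1, keepdim=True)

    original_similarity = cos(text_embed, image_embed)   # text vs. real image
    predicted_similarity = cos(text_embed, predicted)    # text vs. prior sample

    # The scalar logged as "Cosine similarity difference".
    print((predicted_similarity - original_similarity).mean().item())

One consequence of the patch's settings worth noting: the reader slice spans exactly NUM_TEST_EMBEDDINGS rows and the batch size equals NUM_TEST_EMBEDDINGS, so the for loop in report_cosine_sims runs once and the values used in the post-loop return cover the entire test slice.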