diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index a45199d..3c2bf80 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -17,14 +17,24 @@ os.environ["WANDB_SILENT"] = "true" def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"): model.eval() with torch.no_grad(): - for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end), + total_loss = 0. + total_samples = 0. + + for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end), text_reader(batch_size=batch_size, start=start, end=end)): + emb_images_tensor = torch.tensor(emb_images[0]).to(device) emb_text_tensor = torch.tensor(emb_text[0]).to(device) + + batches = emb_images_tensor.shape[0] + loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor) - # Log to wandb - wandb.log({f'{phase} {loss_type}': loss}) + total_loss += loss.item() * batches + total_samples += batches + + avg_loss = (total_loss / total_samples) + wandb.log({f'{phase} {loss_type}': avg_loss}) def save_model(save_path,state_dict): # Saving State Dict