From f7df3caaf314c81f377ffd45f97381e4ae74c110 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Mon, 2 May 2022 08:51:41 -0700
Subject: [PATCH] address not calculating average eval / test loss when training diffusion prior https://github.com/lucidrains/DALLE2-pytorch/issues/49

---
 train_diffusion_prior.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index a45199d..3c2bf80 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -17,14 +17,24 @@ os.environ["WANDB_SILENT"] = "true"
 def eval_model(model,device,image_reader,text_reader,start,end,batch_size,loss_type,phase="Validation"):
     model.eval()
     with torch.no_grad():
-        for emb_images,emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
+        total_loss = 0.
+        total_samples = 0.
+
+        for emb_images, emb_text in zip(image_reader(batch_size=batch_size, start=start, end=end),
                 text_reader(batch_size=batch_size, start=start, end=end)):
+
             emb_images_tensor = torch.tensor(emb_images[0]).to(device)
             emb_text_tensor = torch.tensor(emb_text[0]).to(device)
+
+            batches = emb_images_tensor.shape[0]
+
             loss = model(text_embed = emb_text_tensor, image_embed = emb_images_tensor)
 
-            # Log to wandb
-            wandb.log({f'{phase} {loss_type}': loss})
+            total_loss += loss.item() * batches
+            total_samples += batches
+
+        avg_loss = (total_loss / total_samples)
+        wandb.log({f'{phase} {loss_type}': avg_loss})
 
 def save_model(save_path,state_dict):
     # Saving State Dict