From f5a906f5d35896bd10afa4f5b06107321493fcb2 Mon Sep 17 00:00:00 2001
From: zion <51308183+nousr@users.noreply.github.com>
Date: Sun, 19 Jun 2022 17:55:15 -0500
Subject: [PATCH] prior train script bug fixes (#153)

---
 train_diffusion_prior.py | 67 +++++++++++++++++++++-------------------
 1 file changed, 36 insertions(+), 31 deletions(-)

diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py
index f5e0a15..02f2ce2 100644
--- a/train_diffusion_prior.py
+++ b/train_diffusion_prior.py
@@ -12,7 +12,6 @@ import wandb
 
 import torch
 from torch import nn
-from torch.nn.functional import normalize
 from torch.utils.data import DataLoader
 
 import numpy as np
@@ -68,13 +67,13 @@ def eval_model(
     dataloader: DataLoader,
     text_conditioned: bool,
     loss_type: str,
-    phase: str,
+    tracker_context: str,
     tracker: BaseTracker = None,
     use_ema: bool = True,
 ):
     trainer.eval()
     if trainer.is_main_process():
-        click.secho(f"Measuring performance on {phase}", fg="green", blink=True)
+        click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
 
     with torch.no_grad():
         total_loss = 0.0
@@ -103,7 +102,7 @@ def eval_model(
         avg_loss = total_loss / total_samples
 
-        stats = {f"{phase}/{loss_type}": avg_loss}
+        stats = {f"{tracker_context}-{loss_type}": avg_loss}
         trainer.print(stats)
 
         if exists(tracker):
@@ -115,7 +114,7 @@ def report_cosine_sims(
     dataloader: DataLoader,
     text_conditioned: bool,
     tracker: BaseTracker,
-    phase: str = "validation",
+    tracker_context: str = "validation",
 ):
     trainer.eval()
     if trainer.is_main_process():
@@ -141,7 +140,9 @@ def report_cosine_sims(
         # roll the text to simulate "unrelated" captions
         rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
         text_embed_shuffled = text_embed_shuffled[rolled_idx]
-        text_embed_shuffled = text_embed_shuffled / normalize(text_embed_shuffled)
+        text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
+            dim=1, keepdim=True
+        )
 
         if text_conditioned:
             text_encodings_shuffled = text_encodings[rolled_idx]
@@ -157,18 +158,21 @@ def report_cosine_sims(
         )
 
         # prepare the text embedding
-        text_embed = normalize(text_embedding / text_embedding)
+        text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
 
         # prepare image embeddings
-        test_image_embeddings = test_image_embeddings / normalize(test_image_embeddings)
+        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
+            dim=1, keepdim=True
+        )
 
         # predict on the unshuffled text embeddings
         predicted_image_embeddings = trainer.p_sample_loop(
             test_image_embeddings.shape, text_cond
         )
 
-        predicted_image_embeddings = predicted_image_embeddings / normalize(
+        predicted_image_embeddings = (
             predicted_image_embeddings
+            / predicted_image_embeddings.norm(dim=1, keepdim=True)
         )
 
         # predict on the shuffled embeddings
@@ -176,8 +180,9 @@
         predicted_unrelated_embeddings = trainer.p_sample_loop(
             test_image_embeddings.shape, text_cond_shuffled
         )
-        predicted_unrelated_embeddings = predicted_unrelated_embeddings / normalize(
+        predicted_unrelated_embeddings = (
             predicted_unrelated_embeddings
+            / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
         )
 
         # calculate similarities
@@ -191,19 +196,19 @@
         )
 
         stats = {
-            f"{phase}/baseline similarity": np.mean(original_similarity),
-            f"{phase}/similarity with text": np.mean(predicted_similarity),
-            f"{phase}/similarity with original image": np.mean(
+            f"{tracker_context}/baseline similarity": np.mean(original_similarity),
+            f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
+            f"{tracker_context}/similarity with original image": np.mean(
                 predicted_img_similarity
             ),
-            f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
-            f"{phase}/difference from baseline similarity": np.mean(
+            f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
+            f"{tracker_context}/difference from baseline similarity": np.mean(
                 predicted_similarity - original_similarity
             ),
         }
 
         for k, v in stats.items():
-            trainer.print(f"{phase}/{k}: {v}")
+            trainer.print(f"{tracker_context}/{k}: {v}")
 
         if exists(tracker):
             tracker.log(stats, step=trainer.step.item() + 1)
@@ -269,10 +274,10 @@ def train(
                 # Log on all processes for debugging
                 tracker.log(
                     {
-                        "training/loss": loss,
-                        "samples/sec/rank": samples_per_sec,
-                        "samples/seen": samples_seen,
-                        "ema/decay": ema_decay,
+                        "tracking/samples-sec": samples_per_sec,
+                        "tracking/samples-seen": samples_seen,
+                        "tracking/ema-decay": ema_decay,
+                        "metrics/training-loss": loss,
                     },
                     step=current_step,
                 )
@@ -280,12 +285,12 @@
            # Metric Tracking & Checkpointing (outside of timer's scope)
            if current_step % EVAL_EVERY == 0:
                eval_model(
-                   trainer,
-                   eval_loader,
-                   config.prior.condition_on_text_encodings,
-                   config.prior.loss_type,
-                   "training/validation",
-                   tracker,
+                   trainer=trainer,
+                   dataloader=eval_loader,
+                   text_conditioned=config.prior.condition_on_text_encodings,
+                   loss_type=config.prior.loss_type,
+                   tracker_context="metrics/online-model-validation",
+                   tracker=tracker,
                    use_ema=False,
                )
@@ -293,8 +298,8 @@
                    trainer=trainer,
                    dataloader=eval_loader,
                    text_conditioned=config.prior.condition_on_text_encodings,
-                   loss=config.prior.loss_type,
-                   phase="ema/validation",
+                   loss_type=config.prior.loss_type,
+                   tracker_context="metrics/ema-model-validation",
                    tracker=tracker,
                    use_ema=True,
                )
@@ -304,7 +309,7 @@
                    dataloader=eval_loader,
                    text_conditioned=config.prior.condition_on_text_encodings,
                    tracker=tracker,
-                   phase="ema/validation",
+                   tracker_context="metrics",
                )
 
            if current_step % config.train.save_every == 0:
@@ -320,7 +325,7 @@
        dataloader=test_loader,
        text_conditioned=config.prior.condition_on_text_encodings,
        loss_type=config.prior.loss_type,
-       phase="test",
+       tracker_context="test",
        tracker=tracker,
    )
@@ -329,7 +334,7 @@
        test_loader,
        config.prior.condition_on_text_encodings,
        tracker,
-       phase="test",
+       tracker_context="test",
    )
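
Note (editor's illustration, not part of the patch): the core bug this patch fixes is a misuse
of torch.nn.functional.normalize. The removed lines computed things like `x / normalize(x)`
and `normalize(x / x)`, neither of which produces unit-norm embeddings; the added lines divide
by an explicit L2 norm instead. A minimal sketch of the difference, assuming a hypothetical
2-D batch of embeddings of shape (batch, dim):

    import torch
    from torch.nn.functional import normalize

    x = torch.randn(4, 512)  # hypothetical batch of embeddings

    # Removed pattern 1: x / normalize(x). Since normalize(x) == x / ||x||,
    # every element collapses to its row norm -- not a unit vector.
    wrong = x / normalize(x, dim=1)

    # Removed pattern 2: normalize(x / x). x / x is all ones, so the actual
    # embedding is discarded and every row becomes the same constant vector.
    also_wrong = normalize(x / x, dim=1)

    # The patch's replacement: explicit L2 normalization along the feature dim.
    unit = x / x.norm(dim=1, keepdim=True)
    assert torch.allclose(unit.norm(dim=1), torch.ones(4), atol=1e-5)

    # Equivalent single call, shown only for comparison:
    assert torch.allclose(unit, normalize(x, dim=1), atol=1e-5)

The same one-line fix is applied uniformly to text_embed, text_embed_shuffled,
test_image_embeddings, and both predicted embedding tensors above, so all cosine-similarity
statistics in report_cosine_sims are computed on properly unit-normalized vectors.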