prior train script bug fixes (#153)

Author: zion
Date: 2022-06-19 17:55:15 -05:00
Committed by: GitHub
Parent: 0215237fc6
Commit: f5a906f5d3


@@ -12,7 +12,6 @@ import wandb
 import torch
 from torch import nn
-from torch.nn.functional import normalize
 from torch.utils.data import DataLoader
 import numpy as np
@@ -68,13 +67,13 @@ def eval_model(
     dataloader: DataLoader,
     text_conditioned: bool,
     loss_type: str,
-    phase: str,
+    tracker_context: str,
     tracker: BaseTracker = None,
     use_ema: bool = True,
 ):
     trainer.eval()
     if trainer.is_main_process():
-        click.secho(f"Measuring performance on {phase}", fg="green", blink=True)
+        click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
     with torch.no_grad():
         total_loss = 0.0
@@ -103,7 +102,7 @@ def eval_model(
         avg_loss = total_loss / total_samples
-        stats = {f"{phase}/{loss_type}": avg_loss}
+        stats = {f"{tracker_context}-{loss_type}": avg_loss}
         trainer.print(stats)
         if exists(tracker):
@@ -115,7 +114,7 @@ def report_cosine_sims(
     dataloader: DataLoader,
     text_conditioned: bool,
     tracker: BaseTracker,
-    phase: str = "validation",
+    tracker_context: str = "validation",
 ):
     trainer.eval()
     if trainer.is_main_process():
@@ -141,7 +140,9 @@ def report_cosine_sims(
         # roll the text to simulate "unrelated" captions
         rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
         text_embed_shuffled = text_embed_shuffled[rolled_idx]
-        text_embed_shuffled = text_embed_shuffled / normalize(text_embed_shuffled)
+        text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
+            dim=1, keepdim=True
+        )
         if text_conditioned:
             text_encodings_shuffled = text_encodings[rolled_idx]
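
Note on the change above: torch.nn.functional.normalize(x, dim=1) already rescales each row of x to unit L2 norm, so the old pattern x / normalize(x) does not normalize anything; element-wise it works out to x_ij / (x_ij / ||x_i||) = ||x_i||, meaning every entry collapses to its row's norm. Dividing by x.norm(dim=1, keepdim=True) gives the intended unit-length rows. A minimal sketch of the difference (tensor shape and names are illustrative, not taken from the script):

    import torch
    from torch.nn.functional import normalize

    x = torch.randn(4, 768)  # stand-in for a batch of embeddings

    buggy = x / normalize(x, dim=1)          # every entry becomes (roughly) its row's L2 norm
    fixed = x / x.norm(dim=1, keepdim=True)  # each row rescaled to unit L2 norm

    print(fixed.norm(dim=1))          # ~1.0 for every row
    print(buggy[0, :3], x[0].norm())  # the "buggy" entries all equal the row norm
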
@@ -157,18 +158,21 @@ def report_cosine_sims(
         )
         # prepare the text embedding
-        text_embed = normalize(text_embedding / text_embedding)
+        text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
         # prepare image embeddings
-        test_image_embeddings = test_image_embeddings / normalize(test_image_embeddings)
+        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
+            dim=1, keepdim=True
+        )
         # predict on the unshuffled text embeddings
         predicted_image_embeddings = trainer.p_sample_loop(
             test_image_embeddings.shape, text_cond
         )
-        predicted_image_embeddings = predicted_image_embeddings / normalize(
+        predicted_image_embeddings = (
             predicted_image_embeddings
+            / predicted_image_embeddings.norm(dim=1, keepdim=True)
         )
         # predict on the shuffled embeddings
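
The first removed line in the hunk above, text_embed = normalize(text_embedding / text_embedding), was a second normalization bug: a tensor divided by itself is all ones (and NaN wherever an entry is exactly zero), so the resulting "text embedding" carried no information from the actual text. The replacement divides by the row-wise L2 norm, matching how the image embeddings are handled. A tiny sketch of the old versus new behavior (illustrative tensor only):

    import torch
    from torch.nn.functional import normalize

    text_embedding = torch.randn(2, 8)

    old = normalize(text_embedding / text_embedding)                 # constant rows (1/sqrt(8) everywhere), text is ignored
    new = text_embedding / text_embedding.norm(dim=1, keepdim=True)  # unit-norm rows that keep the text direction

    print(old)              # identical entries in every row
    print(new.norm(dim=1))  # ~1.0
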
@@ -176,8 +180,9 @@ def report_cosine_sims(
             test_image_embeddings.shape, text_cond_shuffled
         )
-        predicted_unrelated_embeddings = predicted_unrelated_embeddings / normalize(
+        predicted_unrelated_embeddings = (
             predicted_unrelated_embeddings
+            / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
         )
         # calculate similarities
@@ -191,19 +196,19 @@ def report_cosine_sims(
         )
         stats = {
-            f"{phase}/baseline similarity": np.mean(original_similarity),
-            f"{phase}/similarity with text": np.mean(predicted_similarity),
-            f"{phase}/similarity with original image": np.mean(
+            f"{tracker_context}/baseline similarity": np.mean(original_similarity),
+            f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
+            f"{tracker_context}/similarity with original image": np.mean(
                 predicted_img_similarity
             ),
-            f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
-            f"{phase}/difference from baseline similarity": np.mean(
+            f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
+            f"{tracker_context}/difference from baseline similarity": np.mean(
                 predicted_similarity - original_similarity
             ),
         }
         for k, v in stats.items():
-            trainer.print(f"{phase}/{k}: {v}")
+            trainer.print(f"{tracker_context}/{k}: {v}")
         if exists(tracker):
             tracker.log(stats, step=trainer.step.item() + 1)
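
For context on the values aggregated above: report_cosine_sims logs row-wise cosine similarities between text and image embeddings (the baseline text-to-image similarity, the same comparison against the predicted image embeddings, predicted vs. original image embeddings, and a control built from the rolled "unrelated" captions). The cos callable is defined earlier in the script and is not visible in this hunk; the sketch below assumes the usual torch.nn.CosineSimilarity(dim=1) and uses stand-in tensors:

    import numpy as np
    import torch
    from torch import nn

    cos = nn.CosineSimilarity(dim=1)  # assumption; the actual callable is outside this diff

    # stand-ins for the embeddings prepared earlier in the function
    text_embed = torch.randn(4, 768)
    test_image_embeddings = torch.randn(4, 768)
    predicted_image_embeddings = torch.randn(4, 768)

    original_similarity = cos(text_embed, test_image_embeddings).cpu().numpy()        # "baseline similarity"
    predicted_similarity = cos(text_embed, predicted_image_embeddings).cpu().numpy()  # "similarity with text"
    print(np.mean(predicted_similarity - original_similarity))                        # "difference from baseline similarity"
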
@@ -269,10 +274,10 @@ def train(
             # Log on all processes for debugging
             tracker.log(
                 {
-                    "training/loss": loss,
-                    "samples/sec/rank": samples_per_sec,
-                    "samples/seen": samples_seen,
-                    "ema/decay": ema_decay,
+                    "tracking/samples-sec": samples_per_sec,
+                    "tracking/samples-seen": samples_seen,
+                    "tracking/ema-decay": ema_decay,
+                    "metrics/training-loss": loss,
                 },
                 step=current_step,
             )
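
The per-step logging above is regrouped so that throughput counters sit under a tracking/ prefix and the loss under metrics/. With a wandb-backed tracker (the script imports wandb), the text before the slash becomes the section a chart is filed under, so the renamed keys keep related panels grouped on the dashboard. The call shape is the same tracker.log(dict, step=...) used elsewhere in the file; annotated for reference:

    tracker.log(
        {
            "tracking/samples-sec": samples_per_sec,   # per-rank throughput for this step
            "tracking/samples-seen": samples_seen,     # cumulative samples processed
            "tracking/ema-decay": ema_decay,           # current EMA decay factor
            "metrics/training-loss": loss,             # online (non-EMA) model loss
        },
        step=current_step,
    )
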
@@ -280,12 +285,12 @@ def train(
         # Metric Tracking & Checkpointing (outside of timer's scope)
         if current_step % EVAL_EVERY == 0:
             eval_model(
-                trainer,
-                eval_loader,
-                config.prior.condition_on_text_encodings,
-                config.prior.loss_type,
-                "training/validation",
-                tracker,
+                trainer=trainer,
+                dataloader=eval_loader,
+                text_conditioned=config.prior.condition_on_text_encodings,
+                loss_type=config.prior.loss_type,
+                tracker_context="metrics/online-model-validation",
+                tracker=tracker,
                 use_ema=False,
             )
@@ -293,8 +298,8 @@ def train(
                 trainer=trainer,
                 dataloader=eval_loader,
                 text_conditioned=config.prior.condition_on_text_encodings,
-                loss=config.prior.loss_type,
-                phase="ema/validation",
+                loss_type=config.prior.loss_type,
+                tracker_context="metrics/ema-model-validation",
                 tracker=tracker,
                 use_ema=True,
             )
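
The call above was one of the concrete bugs this commit fixes: eval_model has no loss parameter (it is loss_type, as the signature earlier in the diff shows), so the old keyword would have raised a TypeError the first time the EMA-validation branch ran; phase= is also renamed to tracker_context to match the new signature. A stripped-down reproduction with placeholder arguments:

    def eval_model(trainer, dataloader, text_conditioned, loss_type, tracker_context, tracker=None, use_ema=True):
        ...  # body elided; the signature mirrors the function defined earlier in this file

    # old call: TypeError: eval_model() got an unexpected keyword argument 'loss'
    # eval_model(trainer=None, dataloader=None, text_conditioned=True, loss="l2", phase="ema/validation")

    # fixed call (values are placeholders)
    eval_model(trainer=None, dataloader=None, text_conditioned=True,
               loss_type="l2", tracker_context="metrics/ema-model-validation")
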
@@ -304,7 +309,7 @@ def train(
                 dataloader=eval_loader,
                 text_conditioned=config.prior.condition_on_text_encodings,
                 tracker=tracker,
-                phase="ema/validation",
+                tracker_context="metrics",
             )
         if current_step % config.train.save_every == 0:
@@ -320,7 +325,7 @@ def train(
         dataloader=test_loader,
         text_conditioned=config.prior.condition_on_text_encodings,
         loss_type=config.prior.loss_type,
-        phase="test",
+        tracker_context="test",
         tracker=tracker,
     )
@@ -329,7 +334,7 @@ def train(
         test_loader,
         config.prior.condition_on_text_encodings,
         tracker,
-        phase="test",
+        tracker_context="test",
)