Mirror of https://github.com/lucidrains/DALLE2-pytorch.git (synced 2025-12-19 09:44:19 +01:00)
prior train script bug fixes (#153)
@@ -12,7 +12,6 @@ import wandb
 import torch
 from torch import nn
-from torch.nn.functional import normalize
 from torch.utils.data import DataLoader
 
 import numpy as np
@@ -68,13 +67,13 @@ def eval_model(
     dataloader: DataLoader,
     text_conditioned: bool,
     loss_type: str,
-    phase: str,
+    tracker_context: str,
     tracker: BaseTracker = None,
     use_ema: bool = True,
 ):
     trainer.eval()
     if trainer.is_main_process():
-        click.secho(f"Measuring performance on {phase}", fg="green", blink=True)
+        click.secho(f"Measuring performance on {tracker_context}", fg="green", blink=True)
 
     with torch.no_grad():
         total_loss = 0.0
@@ -103,7 +102,7 @@ def eval_model(
 
     avg_loss = total_loss / total_samples
 
-    stats = {f"{phase}/{loss_type}": avg_loss}
+    stats = {f"{tracker_context}-{loss_type}": avg_loss}
     trainer.print(stats)
 
     if exists(tracker):
@@ -115,7 +114,7 @@ def report_cosine_sims(
     dataloader: DataLoader,
     text_conditioned: bool,
     tracker: BaseTracker,
-    phase: str = "validation",
+    tracker_context: str = "validation",
 ):
     trainer.eval()
     if trainer.is_main_process():
@@ -141,7 +140,9 @@ def report_cosine_sims(
         # roll the text to simulate "unrelated" captions
         rolled_idx = torch.roll(torch.arange(text_embedding.shape[0]), 1)
         text_embed_shuffled = text_embed_shuffled[rolled_idx]
-        text_embed_shuffled = text_embed_shuffled / normalize(text_embed_shuffled)
+        text_embed_shuffled = text_embed_shuffled / text_embed_shuffled.norm(
+            dim=1, keepdim=True
+        )
 
         if text_conditioned:
             text_encodings_shuffled = text_encodings[rolled_idx]
@@ -157,18 +158,21 @@ def report_cosine_sims(
         )
 
         # prepare the text embedding
-        text_embed = normalize(text_embedding / text_embedding)
+        text_embed = text_embedding / text_embedding.norm(dim=1, keepdim=True)
 
         # prepare image embeddings
-        test_image_embeddings = test_image_embeddings / normalize(test_image_embeddings)
+        test_image_embeddings = test_image_embeddings / test_image_embeddings.norm(
+            dim=1, keepdim=True
+        )
 
         # predict on the unshuffled text embeddings
         predicted_image_embeddings = trainer.p_sample_loop(
             test_image_embeddings.shape, text_cond
         )
 
-        predicted_image_embeddings = predicted_image_embeddings / normalize(
-            predicted_image_embeddings
-        )
+        predicted_image_embeddings = (
+            predicted_image_embeddings
+            / predicted_image_embeddings.norm(dim=1, keepdim=True)
+        )
 
         # predict on the shuffled embeddings
@@ -176,8 +180,9 @@ def report_cosine_sims(
             test_image_embeddings.shape, text_cond_shuffled
         )
 
-        predicted_unrelated_embeddings = predicted_unrelated_embeddings / normalize(
-            predicted_unrelated_embeddings
-        )
+        predicted_unrelated_embeddings = (
+            predicted_unrelated_embeddings
+            / predicted_unrelated_embeddings.norm(dim=1, keepdim=True)
+        )
 
         # calculate similarities
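The three hunks above all repair the same misuse of `torch.nn.functional.normalize`. `F.normalize(x)` already returns the unit-normalized tensor, so dividing `x` by it elementwise produces a tensor whose entries each equal the row norm (or NaN where an entry is zero), not unit vectors. The fix divides by the per-row L2 norm directly. A minimal standalone sketch of both patterns (not from the repo; the tensor shape is illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 8)  # a batch of 4 embeddings

# buggy pattern from the old code: dividing x by its unit vector
# elementwise gives x_ij / (x_ij / ||x_i||) = ||x_i|| for every
# nonzero entry, so nothing is actually normalized
buggy = x / F.normalize(x, dim=1)
print(buggy[0])  # every entry equals x[0].norm()

# fixed pattern: divide each row by its own L2 norm
fixed = x / x.norm(dim=1, keepdim=True)

# equivalent to using F.normalize correctly (up to its eps guard)
assert torch.allclose(fixed, F.normalize(x, dim=1), atol=1e-6)
print(fixed.norm(dim=1))  # all ones
```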
@@ -191,19 +196,19 @@ def report_cosine_sims(
         )
 
         stats = {
-            f"{phase}/baseline similarity": np.mean(original_similarity),
-            f"{phase}/similarity with text": np.mean(predicted_similarity),
-            f"{phase}/similarity with original image": np.mean(
+            f"{tracker_context}/baseline similarity": np.mean(original_similarity),
+            f"{tracker_context}/similarity with text": np.mean(predicted_similarity),
+            f"{tracker_context}/similarity with original image": np.mean(
                 predicted_img_similarity
             ),
-            f"{phase}/similarity with unrelated caption": np.mean(unrelated_similarity),
-            f"{phase}/difference from baseline similarity": np.mean(
+            f"{tracker_context}/similarity with unrelated caption": np.mean(unrelated_similarity),
+            f"{tracker_context}/difference from baseline similarity": np.mean(
                 predicted_similarity - original_similarity
             ),
         }
 
         for k, v in stats.items():
-            trainer.print(f"{phase}/{k}: {v}")
+            trainer.print(f"{tracker_context}/{k}: {v}")
 
         if exists(tracker):
             tracker.log(stats, step=trainer.step.item() + 1)
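For context on the stats being renamed here: once both batches have been L2-normalized as in the hunks above, cosine similarity reduces to a row-wise dot product between text and image embeddings, and the logged values are batch means of those similarities. A small self-contained sketch (the `cos_sim` helper is hypothetical, not part of the script):

```python
import torch

def cos_sim(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # row-wise cosine similarity: normalize each batch, then take
    # the per-row dot product
    a = a / a.norm(dim=1, keepdim=True)
    b = b / b.norm(dim=1, keepdim=True)
    return (a * b).sum(dim=1)

text = torch.randn(4, 8)
image = torch.randn(4, 8)
print(cos_sim(text, image))         # per-sample baseline similarity
print(cos_sim(text, image).mean())  # the kind of scalar stored in stats
```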
@@ -269,10 +274,10 @@ def train(
             # Log on all processes for debugging
             tracker.log(
                 {
-                    "training/loss": loss,
-                    "samples/sec/rank": samples_per_sec,
-                    "samples/seen": samples_seen,
-                    "ema/decay": ema_decay,
+                    "tracking/samples-sec": samples_per_sec,
+                    "tracking/samples-seen": samples_seen,
+                    "tracking/ema-decay": ema_decay,
+                    "metrics/training-loss": loss,
                 },
                 step=current_step,
             )
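The renamed keys group per-step metrics under `tracking/` and `metrics/` prefixes, which slash-aware trackers such as wandb render as separate panel sections in the dashboard. A hedged sketch of the equivalent direct wandb call (offline mode and placeholder values; the real script logs through its `tracker` abstraction instead):

```python
import wandb

# assumes nothing about the training run; offline mode avoids needing
# a logged-in wandb account, and the numbers are placeholders
wandb.init(project="dalle2-prior", mode="offline")
wandb.log(
    {
        "tracking/samples-sec": 1024.0,
        "tracking/samples-seen": 50_000,
        "tracking/ema-decay": 0.9999,
        "metrics/training-loss": 0.123,
    },
    step=100,
)
wandb.finish()
```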
@@ -280,12 +285,12 @@ def train(
         # Metric Tracking & Checkpointing (outside of timer's scope)
         if current_step % EVAL_EVERY == 0:
             eval_model(
-                trainer,
-                eval_loader,
-                config.prior.condition_on_text_encodings,
-                config.prior.loss_type,
-                "training/validation",
-                tracker,
+                trainer=trainer,
+                dataloader=eval_loader,
+                text_conditioned=config.prior.condition_on_text_encodings,
+                loss_type=config.prior.loss_type,
+                tracker_context="metrics/online-model-validation",
+                tracker=tracker,
                 use_ema=False,
             )
 
@@ -293,8 +298,8 @@ def train(
                 trainer=trainer,
                 dataloader=eval_loader,
                 text_conditioned=config.prior.condition_on_text_encodings,
-                loss=config.prior.loss_type,
-                phase="ema/validation",
+                loss_type=config.prior.loss_type,
+                tracker_context="metrics/ema-model-validation",
                 tracker=tracker,
                 use_ema=True,
             )
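The hunk above also fixes a real crash: the old call passed `loss=` to a function whose parameter is `loss_type`, which raises `TypeError` at call time. Switching every `eval_model` call to keyword arguments (as the previous hunk does) makes this kind of signature drift fail loudly instead of silently binding a value to whatever parameter happens to sit in that position. A toy illustration (the function and values are stand-ins, not the script's real objects):

```python
# toy stand-in for eval_model's new signature, for illustration only
def eval_model_v2(trainer, dataloader, text_conditioned, loss_type,
                  tracker_context, tracker=None, use_ema=True):
    print(f"evaluating, logging under {tracker_context!r}")

# a positional call written against an older signature still runs as
# long as the arity matches; values simply bind to whatever parameter
# now occupies each slot
eval_model_v2("trainer", "loader", True, "l2", "training/validation", "tracker")

# a keyword call with a stale name fails immediately and visibly
try:
    eval_model_v2(trainer="trainer", dataloader="loader",
                  text_conditioned=True, loss="l2",
                  tracker_context="metrics/ema-model-validation")
except TypeError as err:
    print(err)  # unexpected keyword argument 'loss'
```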
@@ -304,7 +309,7 @@ def train(
                 dataloader=eval_loader,
                 text_conditioned=config.prior.condition_on_text_encodings,
                 tracker=tracker,
-                phase="ema/validation",
+                tracker_context="metrics",
             )
 
         if current_step % config.train.save_every == 0:
@@ -320,7 +325,7 @@ def train(
                 dataloader=test_loader,
                 text_conditioned=config.prior.condition_on_text_encodings,
                 loss_type=config.prior.loss_type,
-                phase="test",
+                tracker_context="test",
                 tracker=tracker,
             )
 
@@ -329,7 +334,7 @@ def train(
                 test_loader,
                 config.prior.condition_on_text_encodings,
                 tracker,
-                phase="test",
+                tracker_context="test",
             )
 