diff --git a/dalle2_pytorch/utils.py b/dalle2_pytorch/utils.py new file mode 100644 index 0000000..f7cbf86 --- /dev/null +++ b/dalle2_pytorch/utils.py @@ -0,0 +1,11 @@ +import time + +class Timer: + def __init__(self): + self.reset() + + def reset(self): + self.last_time = time.time() + + def elapsed(self): + return time.time() - self.last_time diff --git a/train_decoder.py b/train_decoder.py index 3c91c36..e3daa5e 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -2,8 +2,9 @@ from dalle2_pytorch import Unet, Decoder from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon from dalle2_pytorch.dataloaders import create_image_embedding_dataloader from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker +from dalle2_pytorch.utils import Timer + from configs.decoder_defaults import default_config, ConfigField -import time import json import torchvision from torchvision import transforms as T @@ -260,14 +261,17 @@ def train( send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr] step = start_step + for epoch in range(start_epoch, epochs): print(print_ribbon(f"Starting epoch {epoch}", repeat=40)) trainer.train() + timer = Timer() + sample = 0 last_sample = 0 last_snapshot = 0 - last_time = time.time() + losses = [] for i, (img, emb) in enumerate(dataloaders["train"]): step += 1 @@ -281,8 +285,9 @@ def train( trainer.update(unet_number=unet) losses.append(loss) - samples_per_sec = (sample - last_sample) / (time.time() - last_time) - last_time = time.time() + samples_per_sec = (sample - last_sample) / timer.elapsed() + + timer.reset() last_sample = sample if i % 10 == 0: @@ -320,7 +325,7 @@ def train( with torch.no_grad(): sample = 0 average_loss = 0 - start_time = time.time() + timer = Timer() for i, (img, emb, txt) in enumerate(dataloaders["val"]): sample += img.shape[0] img, emb = send_to_device((img, emb)) @@ -330,12 +335,13 @@ def train( average_loss += loss if i % 10 == 0: - print(f"Epoch {epoch}/{epochs} - 
{sample / (time.time() - start_time):.2f} samples/sec") + print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec") print(f"Loss: {average_loss / (i+1)}") print("") if validation_samples is not None and sample >= validation_samples: break + average_loss /= i+1 log_data = { "Validation loss": average_loss @@ -497,4 +503,4 @@ def main(config_file): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index bb07a69..c7eb1a1 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -1,7 +1,6 @@ from pathlib import Path import click import math -import time import numpy as np import torch @@ -13,6 +12,7 @@ from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdap from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker +from dalle2_pytorch.utils import Timer from embedding_reader import EmbeddingReader @@ -29,16 +29,6 @@ tracker = WandbTracker() def exists(val): - val is not None + return val is not None -class Timer: - def __init__(self): - self.reset() - - def reset(self): - self.last_time = time.time() - - def elapsed(self): - return time.time() - self.last_time - # functions def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):