small cleanup with timer

This commit is contained in:
Phil Wang
2022-05-20 20:05:01 -07:00
parent 022c94e443
commit 8997f178d6
3 changed files with 25 additions and 18 deletions

11
dalle2_pytorch/utils.py Normal file
View File

@@ -0,0 +1,11 @@
import time
class Timer:
    """Simple stopwatch for measuring elapsed time between events.

    The timer starts on construction. Call ``elapsed()`` to read the number
    of seconds since the last ``reset()`` (or since construction).

    Intervals are measured with ``time.perf_counter()``, a monotonic,
    high-resolution clock, so readings are never negative and are not
    affected by system clock adjustments the way ``time.time()`` is.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the timer from the current instant."""
        # NOTE: ``last_time`` is a perf_counter reading; it is only meaningful
        # as the start of an interval, not as an epoch timestamp.
        self.last_time = time.perf_counter()

    def elapsed(self):
        """Return the seconds (float) elapsed since the last reset."""
        return time.perf_counter() - self.last_time

View File

@@ -2,8 +2,9 @@ from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.utils import Timer
from configs.decoder_defaults import default_config, ConfigField
import time
import json
import torchvision
from torchvision import transforms as T
@@ -260,14 +261,17 @@ def train(
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
step = start_step
for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
trainer.train()
timer = Timer()
sample = 0
last_sample = 0
last_snapshot = 0
last_time = time.time()
losses = []
for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1
@@ -281,8 +285,9 @@ def train(
trainer.update(unet_number=unet)
losses.append(loss)
samples_per_sec = (sample - last_sample) / (time.time() - last_time)
last_time = time.time()
samples_per_sec = (sample - last_sample) / timer.elapsed()
timer.reset()
last_sample = sample
if i % 10 == 0:
@@ -320,7 +325,7 @@ def train(
with torch.no_grad():
sample = 0
average_loss = 0
start_time = time.time()
timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
sample += img.shape[0]
img, emb = send_to_device((img, emb))
@@ -330,12 +335,13 @@ def train(
average_loss += loss
if i % 10 == 0:
print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec")
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}")
print("")
if validation_samples is not None and sample >= validation_samples:
break
average_loss /= i+1
log_data = {
"Validation loss": average_loss
@@ -497,4 +503,4 @@ def main(config_file):
if __name__ == "__main__":
main()
main()

View File

@@ -1,7 +1,6 @@
from pathlib import Path
import click
import math
import time
import numpy as np
import torch
@@ -13,6 +12,7 @@ from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdap
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer
from embedding_reader import EmbeddingReader
@@ -29,16 +29,6 @@ tracker = WandbTracker()
def exists(val):
    """Return True when ``val`` is not ``None``, False otherwise.

    Falsy values such as ``0``, ``""`` or ``[]`` still count as existing;
    only ``None`` is treated as missing.
    """
    # The comparison must be returned explicitly; without ``return`` the
    # function always yields ``None`` and every ``if exists(x)`` check fails.
    return val is not None
class Timer:
    """Simple stopwatch for measuring elapsed time between events.

    The timer starts on construction. Call ``elapsed()`` to read the number
    of seconds since the last ``reset()`` (or since construction).

    Intervals are measured with ``time.perf_counter()``, a monotonic,
    high-resolution clock, so readings are never negative and are not
    affected by system clock adjustments the way ``time.time()`` is.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Restart the timer from the current instant."""
        # NOTE: ``last_time`` is a perf_counter reading; it is only meaningful
        # as the start of an interval, not as an epoch timestamp.
        self.last_time = time.perf_counter()

    def elapsed(self):
        """Return the seconds (float) elapsed since the last reset."""
        return time.perf_counter() - self.last_time
# functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):