small cleanup with timer

This commit is contained in:
Phil Wang
2022-05-20 20:05:01 -07:00
parent 022c94e443
commit 8997f178d6
3 changed files with 25 additions and 18 deletions

11
dalle2_pytorch/utils.py Normal file
View File

@@ -0,0 +1,11 @@
import time
class Timer:
    """Simple stopwatch: measures seconds elapsed since the last reset.

    Uses a monotonic clock rather than wall-clock time so that elapsed
    intervals are never negative and are unaffected by NTP/system clock
    adjustments (``time.time()`` can jump backwards).
    """

    def __init__(self):
        # Start timing immediately on construction.
        self.reset()

    def reset(self):
        """Restart the stopwatch from now."""
        # monotonic() is the correct clock for interval measurement;
        # its absolute value is meaningless, only differences matter.
        self.last_time = time.monotonic()

    def elapsed(self):
        """Return seconds (float, >= 0) since the last reset."""
        return time.monotonic() - self.last_time

View File

@@ -2,8 +2,9 @@ from dalle2_pytorch import Unet, Decoder
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
from dalle2_pytorch.utils import Timer
from configs.decoder_defaults import default_config, ConfigField from configs.decoder_defaults import default_config, ConfigField
import time
import json import json
import torchvision import torchvision
from torchvision import transforms as T from torchvision import transforms as T
@@ -260,14 +261,17 @@ def train(
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr] send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
step = start_step step = start_step
for epoch in range(start_epoch, epochs): for epoch in range(start_epoch, epochs):
print(print_ribbon(f"Starting epoch {epoch}", repeat=40)) print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
trainer.train() trainer.train()
timer = Timer()
sample = 0 sample = 0
last_sample = 0 last_sample = 0
last_snapshot = 0 last_snapshot = 0
last_time = time.time()
losses = [] losses = []
for i, (img, emb) in enumerate(dataloaders["train"]): for i, (img, emb) in enumerate(dataloaders["train"]):
step += 1 step += 1
@@ -281,8 +285,9 @@ def train(
trainer.update(unet_number=unet) trainer.update(unet_number=unet)
losses.append(loss) losses.append(loss)
samples_per_sec = (sample - last_sample) / (time.time() - last_time) samples_per_sec = (sample - last_sample) / timer.elapsed()
last_time = time.time()
timer.reset()
last_sample = sample last_sample = sample
if i % 10 == 0: if i % 10 == 0:
@@ -320,7 +325,7 @@ def train(
with torch.no_grad(): with torch.no_grad():
sample = 0 sample = 0
average_loss = 0 average_loss = 0
start_time = time.time() timer = Timer()
for i, (img, emb, txt) in enumerate(dataloaders["val"]): for i, (img, emb, txt) in enumerate(dataloaders["val"]):
sample += img.shape[0] sample += img.shape[0]
img, emb = send_to_device((img, emb)) img, emb = send_to_device((img, emb))
@@ -330,12 +335,13 @@ def train(
average_loss += loss average_loss += loss
if i % 10 == 0: if i % 10 == 0:
print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec") print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
print(f"Loss: {average_loss / (i+1)}") print(f"Loss: {average_loss / (i+1)}")
print("") print("")
if validation_samples is not None and sample >= validation_samples: if validation_samples is not None and sample >= validation_samples:
break break
average_loss /= i+1 average_loss /= i+1
log_data = { log_data = {
"Validation loss": average_loss "Validation loss": average_loss
@@ -497,4 +503,4 @@ def main(config_file):
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,7 +1,6 @@
from pathlib import Path from pathlib import Path
import click import click
import math import math
import time
import numpy as np import numpy as np
import torch import torch
@@ -13,6 +12,7 @@ from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdap
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
from dalle2_pytorch.utils import Timer
from embedding_reader import EmbeddingReader from embedding_reader import EmbeddingReader
@@ -29,16 +29,6 @@ tracker = WandbTracker()
def exists(val): def exists(val):
val is not None val is not None
class Timer:
def __init__(self):
self.reset()
def reset(self):
self.last_time = time.time()
def elapsed(self):
return time.time() - self.last_time
# functions # functions
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"): def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):