mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
small cleanup with timer
This commit is contained in:
11
dalle2_pytorch/utils.py
Normal file
11
dalle2_pytorch/utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import time
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.last_time = time.time()
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
@@ -2,8 +2,9 @@ from dalle2_pytorch import Unet, Decoder
|
||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||
from dalle2_pytorch.utils import Timer
|
||||
|
||||
from configs.decoder_defaults import default_config, ConfigField
|
||||
import time
|
||||
import json
|
||||
import torchvision
|
||||
from torchvision import transforms as T
|
||||
@@ -260,14 +261,17 @@ def train(
|
||||
|
||||
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
|
||||
step = start_step
|
||||
|
||||
for epoch in range(start_epoch, epochs):
|
||||
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
||||
trainer.train()
|
||||
|
||||
timer = Timer()
|
||||
|
||||
sample = 0
|
||||
last_sample = 0
|
||||
last_snapshot = 0
|
||||
last_time = time.time()
|
||||
|
||||
losses = []
|
||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||
step += 1
|
||||
@@ -281,8 +285,9 @@ def train(
|
||||
trainer.update(unet_number=unet)
|
||||
losses.append(loss)
|
||||
|
||||
samples_per_sec = (sample - last_sample) / (time.time() - last_time)
|
||||
last_time = time.time()
|
||||
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
||||
|
||||
timer.reset()
|
||||
last_sample = sample
|
||||
|
||||
if i % 10 == 0:
|
||||
@@ -320,7 +325,7 @@ def train(
|
||||
with torch.no_grad():
|
||||
sample = 0
|
||||
average_loss = 0
|
||||
start_time = time.time()
|
||||
timer = Timer()
|
||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||
sample += img.shape[0]
|
||||
img, emb = send_to_device((img, emb))
|
||||
@@ -330,12 +335,13 @@ def train(
|
||||
average_loss += loss
|
||||
|
||||
if i % 10 == 0:
|
||||
print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec")
|
||||
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
|
||||
print(f"Loss: {average_loss / (i+1)}")
|
||||
print("")
|
||||
|
||||
if validation_samples is not None and sample >= validation_samples:
|
||||
break
|
||||
|
||||
average_loss /= i+1
|
||||
log_data = {
|
||||
"Validation loss": average_loss
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from pathlib import Path
|
||||
import click
|
||||
import math
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
@@ -13,6 +12,7 @@ from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdap
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
from dalle2_pytorch.utils import Timer
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
@@ -29,16 +29,6 @@ tracker = WandbTracker()
|
||||
def exists(val):
|
||||
val is not None
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.last_time = time.time()
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
|
||||
Reference in New Issue
Block a user