mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
small cleanup with timer
This commit is contained in:
11
dalle2_pytorch/utils.py
Normal file
11
dalle2_pytorch/utils.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
import time
|
||||||
|
|
||||||
|
class Timer:
    """Minimal stopwatch reporting the seconds elapsed since the last reset.

    The reference point is taken from ``time.perf_counter()``, a monotonic,
    high-resolution clock. Unlike ``time.time()`` it cannot jump backwards
    when the system clock is adjusted, so ``elapsed()`` is never negative and
    is safe to use as a divisor for throughput figures (samples / second).
    """

    def __init__(self):
        # Start timing immediately so a fresh Timer is usable without reset().
        self.reset()

    def reset(self):
        """Restart the measurement from the current instant."""
        # NOTE: this is a perf_counter reading, not an epoch timestamp; it is
        # only meaningful relative to another perf_counter reading.
        self.last_time = time.perf_counter()

    def elapsed(self):
        """Return the seconds (float) since construction or the last reset()."""
        return time.perf_counter() - self.last_time
|
||||||
@@ -2,8 +2,9 @@ from dalle2_pytorch import Unet, Decoder
|
|||||||
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
from dalle2_pytorch.trainer import DecoderTrainer, print_ribbon
|
||||||
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
from dalle2_pytorch.dataloaders import create_image_embedding_dataloader
|
||||||
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
from dalle2_pytorch.trackers import WandbTracker, ConsoleTracker
|
||||||
|
from dalle2_pytorch.utils import Timer
|
||||||
|
|
||||||
from configs.decoder_defaults import default_config, ConfigField
|
from configs.decoder_defaults import default_config, ConfigField
|
||||||
import time
|
|
||||||
import json
|
import json
|
||||||
import torchvision
|
import torchvision
|
||||||
from torchvision import transforms as T
|
from torchvision import transforms as T
|
||||||
@@ -260,14 +261,17 @@ def train(
|
|||||||
|
|
||||||
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
|
send_to_device = lambda arr: [x.to(device=inference_device, dtype=torch.float) for x in arr]
|
||||||
step = start_step
|
step = start_step
|
||||||
|
|
||||||
for epoch in range(start_epoch, epochs):
|
for epoch in range(start_epoch, epochs):
|
||||||
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
print(print_ribbon(f"Starting epoch {epoch}", repeat=40))
|
||||||
trainer.train()
|
trainer.train()
|
||||||
|
|
||||||
|
timer = Timer()
|
||||||
|
|
||||||
sample = 0
|
sample = 0
|
||||||
last_sample = 0
|
last_sample = 0
|
||||||
last_snapshot = 0
|
last_snapshot = 0
|
||||||
last_time = time.time()
|
|
||||||
losses = []
|
losses = []
|
||||||
for i, (img, emb) in enumerate(dataloaders["train"]):
|
for i, (img, emb) in enumerate(dataloaders["train"]):
|
||||||
step += 1
|
step += 1
|
||||||
@@ -281,8 +285,9 @@ def train(
|
|||||||
trainer.update(unet_number=unet)
|
trainer.update(unet_number=unet)
|
||||||
losses.append(loss)
|
losses.append(loss)
|
||||||
|
|
||||||
samples_per_sec = (sample - last_sample) / (time.time() - last_time)
|
samples_per_sec = (sample - last_sample) / timer.elapsed()
|
||||||
last_time = time.time()
|
|
||||||
|
timer.reset()
|
||||||
last_sample = sample
|
last_sample = sample
|
||||||
|
|
||||||
if i % 10 == 0:
|
if i % 10 == 0:
|
||||||
@@ -320,7 +325,7 @@ def train(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
sample = 0
|
sample = 0
|
||||||
average_loss = 0
|
average_loss = 0
|
||||||
start_time = time.time()
|
timer = Timer()
|
||||||
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
for i, (img, emb, txt) in enumerate(dataloaders["val"]):
|
||||||
sample += img.shape[0]
|
sample += img.shape[0]
|
||||||
img, emb = send_to_device((img, emb))
|
img, emb = send_to_device((img, emb))
|
||||||
@@ -330,12 +335,13 @@ def train(
|
|||||||
average_loss += loss
|
average_loss += loss
|
||||||
|
|
||||||
if i % 10 == 0:
|
if i % 10 == 0:
|
||||||
print(f"Epoch {epoch}/{epochs} - {sample / (time.time() - start_time):.2f} samples/sec")
|
print(f"Epoch {epoch}/{epochs} - {sample / timer.elapsed():.2f} samples/sec")
|
||||||
print(f"Loss: {average_loss / (i+1)}")
|
print(f"Loss: {average_loss / (i+1)}")
|
||||||
print("")
|
print("")
|
||||||
|
|
||||||
if validation_samples is not None and sample >= validation_samples:
|
if validation_samples is not None and sample >= validation_samples:
|
||||||
break
|
break
|
||||||
|
|
||||||
average_loss /= i+1
|
average_loss /= i+1
|
||||||
log_data = {
|
log_data = {
|
||||||
"Validation loss": average_loss
|
"Validation loss": average_loss
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import click
|
import click
|
||||||
import math
|
import math
|
||||||
import time
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -13,6 +12,7 @@ from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdap
|
|||||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||||
|
|
||||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||||
|
from dalle2_pytorch.utils import Timer
|
||||||
|
|
||||||
from embedding_reader import EmbeddingReader
|
from embedding_reader import EmbeddingReader
|
||||||
|
|
||||||
@@ -29,16 +29,6 @@ tracker = WandbTracker()
|
|||||||
def exists(val):
    """Return True when ``val`` is not None, False otherwise.

    Falsy-but-present values such as ``0``, ``""`` or ``False`` count as
    existing; only ``None`` is treated as missing.
    """
    # The comparison result must be returned: without the ``return`` the
    # function always yields None, making every ``if exists(x)`` check falsy.
    return val is not None
|
||||||
|
|
||||||
class Timer:
    """Minimal stopwatch reporting the seconds elapsed since the last reset.

    Uses ``time.perf_counter()`` as its clock: it is monotonic and
    high-resolution, so ``elapsed()`` cannot go negative when the system
    wall clock is adjusted mid-run.
    """

    def __init__(self):
        # Start timing immediately so a fresh Timer is usable without reset().
        self.reset()

    def reset(self):
        """Restart the measurement from the current instant."""
        # Stored value is a perf_counter reading, not an epoch timestamp.
        self.last_time = time.perf_counter()

    def elapsed(self):
        """Return the seconds (float) since construction or the last reset()."""
        return time.perf_counter() - self.last_time
|
|
||||||
|
|
||||||
# functions
|
# functions
|
||||||
|
|
||||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||||
|
|||||||
Reference in New Issue
Block a user