mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
small cleanup with timer
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
from pathlib import Path
|
||||
import click
|
||||
import math
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
@@ -13,6 +12,7 @@ from dalle2_pytorch import DiffusionPrior, DiffusionPriorNetwork, OpenAIClipAdap
|
||||
from dalle2_pytorch.trainer import DiffusionPriorTrainer, load_diffusion_model, save_diffusion_model, print_ribbon
|
||||
|
||||
from dalle2_pytorch.trackers import ConsoleTracker, WandbTracker
|
||||
from dalle2_pytorch.utils import Timer
|
||||
|
||||
from embedding_reader import EmbeddingReader
|
||||
|
||||
@@ -29,16 +29,6 @@ tracker = WandbTracker()
|
||||
def exists(val):
|
||||
val is not None
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.last_time = time.time()
|
||||
|
||||
def elapsed(self):
|
||||
return time.time() - self.last_time
|
||||
|
||||
# functions
|
||||
|
||||
def eval_model(model, dataloader, text_conditioned, loss_type, phase="Validation"):
|
||||
|
||||
Reference in New Issue
Block a user