diff --git a/train_diffusion_prior.py b/train_diffusion_prior.py index 4ab2d09..77575d1 100644 --- a/train_diffusion_prior.py +++ b/train_diffusion_prior.py @@ -24,6 +24,9 @@ tracker = WandbTracker() # helpers functions +def exists(val): + val is not None + class Timer: def __init__(self): self.reset() @@ -167,9 +170,6 @@ def train(image_embed_dim, if RESUME: diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device) - - # TODO, optimizer and scaler needs to be loaded as well - tracker.init(entity = wandb_entity, project = wandb_project, config = config) # diffusion prior trainer @@ -353,15 +353,12 @@ def main( } } - RESUME = False - # Check if DPRIOR_PATH exists(saved model path) DPRIOR_PATH = args.pretrained_model_path + RESUME = exists(DPRIOR_PATH) - if(DPRIOR_PATH is not None): - RESUME = True - else: + if not RESUME: tracker.init( entity = wandb_entity, project = wandb_project,