remove todo

This commit is contained in:
Phil Wang
2022-05-15 11:01:35 -07:00
parent aa6772dcff
commit 74f222596a

View File

@@ -24,6 +24,9 @@ tracker = WandbTracker()
# helpers functions # helpers functions
def exists(val):
val is not None
class Timer: class Timer:
def __init__(self): def __init__(self):
self.reset() self.reset()
@@ -167,9 +170,6 @@ def train(image_embed_dim,
if RESUME: if RESUME:
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device) diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
# TODO, optimizer and scaler needs to be loaded as well
tracker.init(entity = wandb_entity, project = wandb_project, config = config) tracker.init(entity = wandb_entity, project = wandb_project, config = config)
# diffusion prior trainer # diffusion prior trainer
@@ -353,15 +353,12 @@ def main(
} }
} }
RESUME = False
# Check if DPRIOR_PATH exists(saved model path) # Check if DPRIOR_PATH exists(saved model path)
DPRIOR_PATH = args.pretrained_model_path DPRIOR_PATH = args.pretrained_model_path
RESUME = exists(DPRIOR_PATH)
if(DPRIOR_PATH is not None): if not RESUME:
RESUME = True
else:
tracker.init( tracker.init(
entity = wandb_entity, entity = wandb_entity,
project = wandb_project, project = wandb_project,