mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
remove todo
This commit is contained in:
@@ -24,6 +24,9 @@ tracker = WandbTracker()
|
||||
|
||||
# helpers functions
|
||||
|
||||
def exists(val):
|
||||
val is not None
|
||||
|
||||
class Timer:
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
@@ -167,9 +170,6 @@ def train(image_embed_dim,
|
||||
|
||||
if RESUME:
|
||||
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||
|
||||
# TODO, optimizer and scaler needs to be loaded as well
|
||||
|
||||
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||
|
||||
# diffusion prior trainer
|
||||
@@ -353,15 +353,12 @@ def main(
|
||||
}
|
||||
}
|
||||
|
||||
RESUME = False
|
||||
|
||||
# Check if DPRIOR_PATH exists(saved model path)
|
||||
|
||||
DPRIOR_PATH = args.pretrained_model_path
|
||||
RESUME = exists(DPRIOR_PATH)
|
||||
|
||||
if(DPRIOR_PATH is not None):
|
||||
RESUME = True
|
||||
else:
|
||||
if not RESUME:
|
||||
tracker.init(
|
||||
entity = wandb_entity,
|
||||
project = wandb_project,
|
||||
|
||||
Reference in New Issue
Block a user