mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 09:44:19 +01:00
remove todo
This commit is contained in:
@@ -24,6 +24,9 @@ tracker = WandbTracker()
|
|||||||
|
|
||||||
# helpers functions
|
# helpers functions
|
||||||
|
|
||||||
|
def exists(val):
|
||||||
|
val is not None
|
||||||
|
|
||||||
class Timer:
|
class Timer:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -167,9 +170,6 @@ def train(image_embed_dim,
|
|||||||
|
|
||||||
if RESUME:
|
if RESUME:
|
||||||
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
diffusion_prior, loaded_obj = load_diffusion_model(DPRIOR_PATH, device)
|
||||||
|
|
||||||
# TODO, optimizer and scaler needs to be loaded as well
|
|
||||||
|
|
||||||
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
tracker.init(entity = wandb_entity, project = wandb_project, config = config)
|
||||||
|
|
||||||
# diffusion prior trainer
|
# diffusion prior trainer
|
||||||
@@ -353,15 +353,12 @@ def main(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RESUME = False
|
|
||||||
|
|
||||||
# Check if DPRIOR_PATH exists(saved model path)
|
# Check if DPRIOR_PATH exists(saved model path)
|
||||||
|
|
||||||
DPRIOR_PATH = args.pretrained_model_path
|
DPRIOR_PATH = args.pretrained_model_path
|
||||||
|
RESUME = exists(DPRIOR_PATH)
|
||||||
|
|
||||||
if(DPRIOR_PATH is not None):
|
if not RESUME:
|
||||||
RESUME = True
|
|
||||||
else:
|
|
||||||
tracker.init(
|
tracker.init(
|
||||||
entity = wandb_entity,
|
entity = wandb_entity,
|
||||||
project = wandb_project,
|
project = wandb_project,
|
||||||
|
|||||||
Reference in New Issue
Block a user