make sure optimizer and scaler is reloaded on resume for training diffusion prior script, move argparse to click

This commit is contained in:
Phil Wang
2022-05-15 10:48:10 -07:00
parent 71d0c4edae
commit aa6772dcff
2 changed files with 140 additions and 101 deletions

View File

@@ -117,7 +117,7 @@ def load_diffusion_model(dprior_path, device):
# Load state dict from saved model
diffusion_prior.load_state_dict(loaded_obj['model'])
return diffusion_prior
return diffusion_prior, loaded_obj
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
# Saving State Dict