mirror of
https://github.com/lucidrains/DALLE2-pytorch.git
synced 2025-12-19 17:54:20 +01:00
make sure optimizer and scaler is reloaded on resume for training diffusion prior script, move argparse to click
This commit is contained in:
@@ -117,7 +117,7 @@ def load_diffusion_model(dprior_path, device):
|
||||
# Load state dict from saved model
|
||||
diffusion_prior.load_state_dict(loaded_obj['model'])
|
||||
|
||||
return diffusion_prior
|
||||
return diffusion_prior, loaded_obj
|
||||
|
||||
def save_diffusion_model(save_path, model, optimizer, scaler, config, image_embed_dim):
|
||||
# Saving State Dict
|
||||
|
||||
Reference in New Issue
Block a user