Phil Wang
2022-06-07 09:03:48 -07:00
parent 350a3d6045
commit bee5bf3815

View File

@@ -211,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config):
Loads the model with an appropriate method depending on the tracker Loads the model with an appropriate method depending on the tracker
""" """
print(print_ribbon(f"Loading model from {recall_source}")) print(print_ribbon(f"Loading model from {recall_source}"))
state_dict = tracker.recall_state_dict(recall_source, **load_config) state_dict = tracker.recall_state_dict(recall_source, **load_config.dict())
trainer.load_state_dict(state_dict["trainer"]) trainer.load_state_dict(state_dict["trainer"])
print("Model loaded") print("Model loaded")
return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"] return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]