diff --git a/train_decoder.py b/train_decoder.py index 76b49c2..c6ed801 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -211,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config): Loads the model with an appropriate method depending on the tracker """ print(print_ribbon(f"Loading model from {recall_source}")) - state_dict = tracker.recall_state_dict(recall_source, **load_config) + state_dict = tracker.recall_state_dict(recall_source, **load_config.dict()) trainer.load_state_dict(state_dict["trainer"]) print("Model loaded") return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]