From bee5bf38159cf44162086f4db71daf505a55085c Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Tue, 7 Jun 2022 09:03:48 -0700 Subject: [PATCH] fix for https://github.com/lucidrains/DALLE2-pytorch/issues/143 --- train_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train_decoder.py b/train_decoder.py index 76b49c2..c6ed801 100644 --- a/train_decoder.py +++ b/train_decoder.py @@ -211,7 +211,7 @@ def recall_trainer(tracker, trainer, recall_source=None, **load_config): Loads the model with an appropriate method depending on the tracker """ print(print_ribbon(f"Loading model from {recall_source}")) - state_dict = tracker.recall_state_dict(recall_source, **load_config) + state_dict = tracker.recall_state_dict(recall_source, **load_config.dict()) trainer.load_state_dict(state_dict["trainer"]) print("Model loaded") return state_dict["epoch"], state_dict["step"], state_dict["validation_losses"]