Fix checkpoint loading too

This commit is contained in:
Stephan Auerhahn
2023-08-03 17:56:24 -07:00
parent 84d3a7f6f5
commit 4aea6fa2a4

View File

@@ -178,6 +178,9 @@ class SamplingPipeline:
if model_path is None:
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
if not os.path.exists(model_path):
# This supports development installs where checkpoints is root level of the repo
model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
if config_path is None:
config_path = (
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"