This commit is contained in:
Stephan Auerhahn
2023-08-03 17:57:55 -07:00
parent 4aea6fa2a4
commit 77d0e27747

View File

@@ -180,14 +180,20 @@ class SamplingPipeline:
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
if not os.path.exists(model_path):
# This supports development installs where checkpoints is root level of the repo
model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
model_path = (
pathlib.Path(__file__).parent.parent.parent.resolve()
/ "checkpoints"
)
if config_path is None:
config_path = (
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
)
if not os.path.exists(config_path):
# This supports development installs where configs is root level of the repo
config_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
config_path = (
pathlib.Path(__file__).parent.parent.parent.resolve()
/ "configs/inference"
)
self.config = str(config_path / self.specs.config)
self.ckpt = str(model_path / self.specs.ckpt)
if not os.path.exists(self.config):