diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 2b5500c..4f7c78b 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -180,14 +180,20 @@ class SamplingPipeline: model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints" if not os.path.exists(model_path): # This supports development installs where checkpoints is root level of the repo - model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" + model_path = ( + pathlib.Path(__file__).parent.parent.parent.resolve() + / "checkpoints" + ) if config_path is None: config_path = ( pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" ) if not os.path.exists(config_path): # This supports development installs where configs is root level of the repo - config_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference" + config_path = ( + pathlib.Path(__file__).parent.parent.parent.resolve() + / "configs/inference" + ) self.config = str(config_path / self.specs.config) self.ckpt = str(model_path / self.specs.ckpt) if not os.path.exists(self.config):