diff --git a/sgm/inference/api.py b/sgm/inference/api.py index f5ce36f..f81e790 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -185,6 +185,11 @@ class SamplingPipeline: self.config = str(config_path / self.specs.config) self.ckpt = str(model_path / self.specs.ckpt) if not os.path.exists(self.config): + # This supports development installs where configs is root level of the repo + if config_path is None: + config_path = ( + pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference" + ) raise ValueError( f"Config {self.config} not found, check model spec or config_path" )