diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 182e8cf..f5ce36f 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -165,7 +165,7 @@ class SamplingPipeline: config_path=None, device="cuda", use_fp16=True, - ) -> None: + ) -> None: self.model_id = model_id if model_spec is not None: self.specs = model_spec @@ -179,13 +179,19 @@ class SamplingPipeline: if model_path is None: model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints" if config_path is None: - config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" + config_path = ( + pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" + ) self.config = str(config_path / self.specs.config) self.ckpt = str(model_path / self.specs.ckpt) if not os.path.exists(self.config): - raise ValueError(f"Config {self.config} not found, check model spec or config_path") + raise ValueError( + f"Config {self.config} not found, check model spec or config_path" + ) if not os.path.exists(self.ckpt): - raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path") + raise ValueError( + f"Checkpoint {self.ckpt} not found, check model spec or config_path" + ) self.device = device self.model = self._load_model(device=device, use_fp16=use_fp16)