From 76ca428422e99e744e2308e357c3aff6ed5ae2b2 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sun, 6 Aug 2023 21:39:18 +0000 Subject: [PATCH] fix path resolution bug --- sgm/inference/api.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 1ab892f..fd89558 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -181,30 +181,20 @@ class SamplingPipeline: model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints" if not os.path.exists(model_path): # This supports development installs where checkpoints is root level of the repo - model_path = ( - pathlib.Path(__file__).parent.parent.parent.resolve() - / "checkpoints" - ) + model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" if config_path is None: - config_path = ( - pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" - ) + config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" if not os.path.exists(config_path): # This supports development installs where configs is root level of the repo config_path = ( - pathlib.Path(__file__).parent.parent.parent.resolve() - / "configs/inference" + pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference" ) - self.config = str(config_path / self.specs.config) - self.ckpt = str(model_path / self.specs.ckpt) + self.config = str(pathlib.Path(config_path) / self.specs.config) + self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt) if not os.path.exists(self.config): - raise ValueError( - f"Config {self.config} not found, check model spec or config_path" - ) + raise ValueError(f"Config {self.config} not found, check model spec or config_path") if not os.path.exists(self.ckpt): - raise ValueError( - f"Checkpoint {self.ckpt} not found, check model spec or config_path" - ) + raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path") self.device = device self.model = self._load_model(device=device, use_fp16=use_fp16) @@ -300,9 +290,7 @@ class SamplingPipeline: ): return discretization # Already wrapped if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: - discretization = Img2ImgDiscretizationWrapper( - discretization, strength=image_strength - ) + discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength) if ( noise_strength is not None @@ -361,9 +349,7 @@ class SamplingPipeline: def get_guider_config(params: SamplingParams): if params.guider == Guider.IDENTITY: - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" - } + guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} elif params.guider == Guider.VANILLA: scale = params.scale