diff --git a/sgm/inference/api.py b/sgm/inference/api.py index fd89558..89b7370 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -62,9 +62,9 @@ class SamplingParams: discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA thresholder: Thresholder = Thresholder.NONE - scale: float = 6.0 - aesthetic_score: float = 5.0 - negative_aesthetic_score: float = 5.0 + scale: float = 5.0 + aesthetic_score: float = 6.0 + negative_aesthetic_score: float = 2.5 img2img_strength: float = 1.0 orig_width: int = 1024 orig_height: int = 1024 @@ -181,20 +181,30 @@ class SamplingPipeline: model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints" if not os.path.exists(model_path): # This supports development installs where checkpoints is root level of the repo - model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints" + model_path = ( + pathlib.Path(__file__).parent.parent.parent.resolve() + / "checkpoints" + ) if config_path is None: - config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" + config_path = ( + pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" + ) if not os.path.exists(config_path): # This supports development installs where configs is root level of the repo config_path = ( - pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference" + pathlib.Path(__file__).parent.parent.parent.resolve() + / "configs/inference" ) self.config = str(pathlib.Path(config_path) / self.specs.config) self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt) if not os.path.exists(self.config): - raise ValueError(f"Config {self.config} not found, check model spec or config_path") + raise ValueError( + f"Config {self.config} not found, check model spec or config_path" + ) if not os.path.exists(self.ckpt): - raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path") + raise ValueError( + f"Checkpoint {self.ckpt} not found, check model spec or config_path" + ) self.device = device self.model = self._load_model(device=device, use_fp16=use_fp16) @@ -290,7 +300,9 @@ class SamplingPipeline: ): return discretization # Already wrapped if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: - discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength) + discretization = Img2ImgDiscretizationWrapper( + discretization, strength=image_strength + ) if ( noise_strength is not None @@ -349,7 +361,9 @@ class SamplingPipeline: def get_guider_config(params: SamplingParams): if params.guider == Guider.IDENTITY: - guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } elif params.guider == Guider.VANILLA: scale = params.scale