From 5fde7e73b80ff781bb753d1fa2bf6c0b7d1c8c45 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 05:35:36 -0700 Subject: [PATCH] set a default scale --- sgm/inference/api.py | 26 +++++++------------------- 1 file changed, 7 insertions(+), 19 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index d2d5a7d..51269b6 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -68,7 +68,7 @@ class SamplingParams: discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA thresholder: Thresholder = Thresholder.NONE - scale: float + scale: float = 5.0 aesthetic_score: float = 6.0 negative_aesthetic_score: float = 2.5 img2img_strength: float = 1.0 @@ -185,17 +185,13 @@ model_specs = { } -def wrap_discretization( - discretization, image_strength=None, noise_strength=None, steps=None -): +def wrap_discretization(discretization, image_strength=None, noise_strength=None, steps=None): if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance( discretization, Txt2NoisyDiscretizationWrapper ): return discretization # Already wrapped if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: - discretization = Img2ImgDiscretizationWrapper( - discretization, strength=image_strength - ) + discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength) if ( noise_strength is not None @@ -249,19 +245,13 @@ class SamplingPipeline: self.config = os.path.join(config_path, "inference", self.specs.config) self.ckpt = os.path.join(model_path, self.specs.ckpt) if not os.path.exists(self.config): - raise ValueError( - f"Config {self.config} not found, check model spec or config_path" - ) + raise ValueError(f"Config {self.config} not found, check model spec or config_path") if not os.path.exists(self.ckpt): - raise ValueError( - f"Checkpoint {self.ckpt} not found, check model spec or config_path" - ) + raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path") self.device_manager = get_model_manager(device) - self.model = self._load_model( - device_manager=self.device_manager, use_fp16=use_fp16 - ) + self.model = self._load_model(device_manager=self.device_manager, use_fp16=use_fp16) def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): config = OmegaConf.load(self.config) @@ -406,9 +396,7 @@ class SamplingPipeline: def get_guider_config(params: SamplingParams) -> Dict[str, Any]: guider_config: Dict[str, Any] if params.guider == Guider.IDENTITY: - guider_config = { - "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" - } + guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} elif params.guider == Guider.VANILLA: scale = params.scale