run black

This commit is contained in:
Stephan Auerhahn
2023-08-12 05:40:25 -07:00
parent 5fde7e73b8
commit 65c6ec1cec
3 changed files with 51 additions and 17 deletions

View File

@@ -185,13 +185,17 @@ model_specs = {
}
def wrap_discretization(discretization, image_strength=None, noise_strength=None, steps=None):
def wrap_discretization(
discretization, image_strength=None, noise_strength=None, steps=None
):
if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
discretization, Txt2NoisyDiscretizationWrapper
):
return discretization # Already wrapped
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength)
discretization = Img2ImgDiscretizationWrapper(
discretization, strength=image_strength
)
if (
noise_strength is not None
@@ -245,13 +249,19 @@ class SamplingPipeline:
self.config = os.path.join(config_path, "inference", self.specs.config)
self.ckpt = os.path.join(model_path, self.specs.ckpt)
if not os.path.exists(self.config):
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
raise ValueError(
f"Config {self.config} not found, check model spec or config_path"
)
if not os.path.exists(self.ckpt):
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
self.device_manager = get_model_manager(device)
self.model = self._load_model(device_manager=self.device_manager, use_fp16=use_fp16)
self.model = self._load_model(
device_manager=self.device_manager, use_fp16=use_fp16
)
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
config = OmegaConf.load(self.config)
@@ -396,7 +406,9 @@ class SamplingPipeline:
def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
guider_config: Dict[str, Any]
if params.guider == Guider.IDENTITY:
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif params.guider == Guider.VANILLA:
scale = params.scale