From 65c6ec1cecd3ce6fb47ce714ec071fc4af3c94f2 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 05:40:25 -0700 Subject: [PATCH] run black --- scripts/demo/sampling.py | 24 ++++++++++++++++++------ scripts/demo/streamlit_helpers.py | 20 +++++++++++++++----- sgm/inference/api.py | 24 ++++++++++++++++++------ 3 files changed, 51 insertions(+), 17 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 20a8f03..017db21 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -66,7 +66,9 @@ def load_img(display=True, key=None, device="cuda"): st.image(image) w, h = image.size print(f"loaded input image of size ({w}, {h})") - width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 + width, height = map( + lambda x: x - x % 64, (w, h) + ) # resize to integer multiple of 64 image = image.resize((width, height)) image = np.array(image.convert("RGB")) image = image[None].transpose(0, 3, 1, 2) @@ -85,10 +87,16 @@ def run_txt2img( model: SamplingPipeline = state["model"] params: SamplingParams = state["params"] if model_id in sdxl_base_model_list: - width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10) + width, height = st.selectbox( + "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10 + ) else: - height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048)) - width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048)) + height = int( + st.number_input("H", value=params.height, min_value=64, max_value=2048) + ) + width = int( + st.number_input("W", value=params.width, min_value=64, max_value=2048) + ) params = init_embedder_options( get_unique_embedder_keys_from_conditioner(model.model.conditioner), @@ -230,7 +238,9 @@ if __name__ == "__main__": else: add_pipeline = False - seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))) + seed = int( + st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)) + ) seed_everything(seed) save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version))) @@ -270,7 +280,9 @@ if __name__ == "__main__": st.write("**Refiner Options:**") specs2 = model_specs[version2] - state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap) + state2 = init_st( + specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap + ) params2 = state2["params"] params2.img2img_strength = st.number_input( diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index a1770a8..eeb0f20 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -132,7 +132,9 @@ def show_samples(samples, outputs): def get_guider(params: SamplingParams, key=1) -> SamplingParams: params.guider = Guider( - st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider]) + st.sidebar.selectbox( + f"Discretization #{key}", [member.value for member in Guider] + ) ) if params.guider == Guider.VANILLA: @@ -163,10 +165,14 @@ def init_sampling( num_rows, num_cols = 1, 1 if specify_num_samples: - num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10) + num_cols = st.number_input( + f"num cols #{key}", value=2, min_value=1, max_value=10 + ) params.steps = int( - st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000) + st.sidebar.number_input( + f"steps #{key}", value=params.steps, min_value=1, max_value=1000 + ) ) params.sampler = Sampler( @@ -214,11 +220,15 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams: ) elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL): - params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0) + params.s_noise = st.sidebar.number_input( + "s_noise", value=params.s_noise, min_value=0.0 + ) params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0) elif params.sampler == Sampler.LINEAR_MULTISTEP: - params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1)) + params.order = int( + st.sidebar.number_input("order", value=params.order, min_value=1) + ) return params diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 51269b6..d863f5e 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -185,13 +185,17 @@ model_specs = { } -def wrap_discretization(discretization, image_strength=None, noise_strength=None, steps=None): +def wrap_discretization( + discretization, image_strength=None, noise_strength=None, steps=None +): if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance( discretization, Txt2NoisyDiscretizationWrapper ): return discretization # Already wrapped if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: - discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength) + discretization = Img2ImgDiscretizationWrapper( + discretization, strength=image_strength + ) if ( noise_strength is not None @@ -245,13 +249,19 @@ class SamplingPipeline: self.config = os.path.join(config_path, "inference", self.specs.config) self.ckpt = os.path.join(model_path, self.specs.ckpt) if not os.path.exists(self.config): - raise ValueError(f"Config {self.config} not found, check model spec or config_path") + raise ValueError( + f"Config {self.config} not found, check model spec or config_path" + ) if not os.path.exists(self.ckpt): - raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path") + raise ValueError( + f"Checkpoint {self.ckpt} not found, check model spec or config_path" + ) self.device_manager = get_model_manager(device) - self.model = self._load_model(device_manager=self.device_manager, use_fp16=use_fp16) + self.model = self._load_model( + device_manager=self.device_manager, use_fp16=use_fp16 + ) def _load_model(self, device_manager: DeviceModelManager, use_fp16=True): config = OmegaConf.load(self.config) @@ -396,7 +406,9 @@ class SamplingPipeline: def get_guider_config(params: SamplingParams) -> Dict[str, Any]: guider_config: Dict[str, Any] if params.guider == Guider.IDENTITY: - guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"} + guider_config = { + "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" + } elif params.guider == Guider.VANILLA: scale = params.scale