From a009aa8a9f58918beff28cbd62bcdb1615986c1a Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 13:27:30 -0700 Subject: [PATCH] adding some typing --- scripts/demo/streamlit_helpers.py | 9 +++++---- sgm/inference/api.py | 8 +++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 70b7d06..c25284f 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -6,7 +6,7 @@ import torch from einops import rearrange, repeat from PIL import Image from torchvision import transforms -from typing import Optional, Tuple +from typing import Optional, Tuple, Dict, Any from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering @@ -26,9 +26,9 @@ from sgm.inference.helpers import ( @st.cache_resource() -def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True): +def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]: global lowvram_mode - state = dict() + state: Dict[str, Any] = dict() if not "model" in state: config = spec.config ckpt = spec.ckpt @@ -244,9 +244,10 @@ def get_interactive_image(key=None) -> Optional[Image.Image]: if not image.mode == "RGB": image = image.convert("RGB") return image + return None -def load_img(display=True, key=None) -> torch.Tensor: +def load_img(display=True, key=None) -> Optional[torch.Tensor]: image = get_interactive_image(key=key) if image is None: return None diff --git a/sgm/inference/api.py b/sgm/inference/api.py index ad6aecc..668cc65 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -18,7 +18,7 @@ from sgm.modules.diffusionmodules.sampling import ( LinearMultistepSampler, ) from sgm.util import load_model_from_config -from typing import Optional +from typing import Optional, Dict, Any class ModelArchitecture(str, Enum): @@ -363,7 +363,8 @@ class SamplingPipeline: ) -def get_guider_config(params: SamplingParams): +def get_guider_config(params: SamplingParams) -> Dict[str, Any]: + guider_config: Dict[str, Any] if params.guider == Guider.IDENTITY: guider_config = { "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider" @@ -389,7 +390,8 @@ def get_guider_config(params: SamplingParams): return guider_config -def get_discretization_config(params: SamplingParams): +def get_discretization_config(params: SamplingParams) -> Dict[str, Any]: + discretization_config: Dict[str, Any] if params.discretization == Discretization.LEGACY_DDPM: discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",