adding some typing

This commit is contained in:
Stephan Auerhahn
2023-08-09 13:27:30 -07:00
parent f86ffac274
commit a009aa8a9f
2 changed files with 10 additions and 7 deletions

View File

@@ -6,7 +6,7 @@ import torch
from einops import rearrange, repeat
from PIL import Image
from torchvision import transforms
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict, Any
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
@@ -26,9 +26,9 @@ from sgm.inference.helpers import (
@st.cache_resource()
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True):
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
global lowvram_mode
state = dict()
state: Dict[str, Any] = dict()
if not "model" in state:
config = spec.config
ckpt = spec.ckpt
@@ -244,9 +244,10 @@ def get_interactive_image(key=None) -> Optional[Image.Image]:
if not image.mode == "RGB":
image = image.convert("RGB")
return image
return None
def load_img(display=True, key=None) -> torch.Tensor:
def load_img(display=True, key=None) -> Optional[torch.Tensor]:
image = get_interactive_image(key=key)
if image is None:
return None

View File

@@ -18,7 +18,7 @@ from sgm.modules.diffusionmodules.sampling import (
LinearMultistepSampler,
)
from sgm.util import load_model_from_config
from typing import Optional
from typing import Optional, Dict, Any
class ModelArchitecture(str, Enum):
@@ -363,7 +363,8 @@ class SamplingPipeline:
)
def get_guider_config(params: SamplingParams):
def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
guider_config: Dict[str, Any]
if params.guider == Guider.IDENTITY:
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
@@ -389,7 +390,8 @@ def get_guider_config(params: SamplingParams):
return guider_config
def get_discretization_config(params: SamplingParams):
def get_discretization_config(params: SamplingParams) -> Dict[str, Any]:
discretization_config: Dict[str, Any]
if params.discretization == Discretization.LEGACY_DDPM:
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",