adding some typing

This commit is contained in:
Stephan Auerhahn
2023-08-09 13:27:30 -07:00
parent f86ffac274
commit a009aa8a9f
2 changed files with 10 additions and 7 deletions

View File

@@ -6,7 +6,7 @@ import torch
from einops import rearrange, repeat
from PIL import Image
from torchvision import transforms
from typing import Optional, Tuple
from typing import Optional, Tuple, Dict, Any
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
@@ -26,9 +26,9 @@ from sgm.inference.helpers import (
@st.cache_resource()
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True):
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
global lowvram_mode
state = dict()
state: Dict[str, Any] = dict()
if not "model" in state:
config = spec.config
ckpt = spec.ckpt
@@ -244,9 +244,10 @@ def get_interactive_image(key=None) -> Optional[Image.Image]:
if not image.mode == "RGB":
image = image.convert("RGB")
return image
return None
def load_img(display=True, key=None) -> torch.Tensor:
def load_img(display=True, key=None) -> Optional[torch.Tensor]:
image = get_interactive_image(key=key)
if image is None:
return None