mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
adding some typing
This commit is contained in:
@@ -6,7 +6,7 @@ import torch
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
|
||||
|
||||
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||
@@ -26,9 +26,9 @@ from sgm.inference.helpers import (
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True):
|
||||
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
|
||||
global lowvram_mode
|
||||
state = dict()
|
||||
state: Dict[str, Any] = dict()
|
||||
if not "model" in state:
|
||||
config = spec.config
|
||||
ckpt = spec.ckpt
|
||||
@@ -244,9 +244,10 @@ def get_interactive_image(key=None) -> Optional[Image.Image]:
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
return None
|
||||
|
||||
|
||||
def load_img(display=True, key=None) -> torch.Tensor:
|
||||
def load_img(display=True, key=None) -> Optional[torch.Tensor]:
|
||||
image = get_interactive_image(key=key)
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user