mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
adding some typing
This commit is contained in:
@@ -6,7 +6,7 @@ import torch
|
||||
from einops import rearrange, repeat
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Dict, Any
|
||||
|
||||
|
||||
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||
@@ -26,9 +26,9 @@ from sgm.inference.helpers import (
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True):
|
||||
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
|
||||
global lowvram_mode
|
||||
state = dict()
|
||||
state: Dict[str, Any] = dict()
|
||||
if not "model" in state:
|
||||
config = spec.config
|
||||
ckpt = spec.ckpt
|
||||
@@ -244,9 +244,10 @@ def get_interactive_image(key=None) -> Optional[Image.Image]:
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
return None
|
||||
|
||||
|
||||
def load_img(display=True, key=None) -> torch.Tensor:
|
||||
def load_img(display=True, key=None) -> Optional[torch.Tensor]:
|
||||
image = get_interactive_image(key=key)
|
||||
if image is None:
|
||||
return None
|
||||
|
||||
@@ -18,7 +18,7 @@ from sgm.modules.diffusionmodules.sampling import (
|
||||
LinearMultistepSampler,
|
||||
)
|
||||
from sgm.util import load_model_from_config
|
||||
from typing import Optional
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
|
||||
class ModelArchitecture(str, Enum):
|
||||
@@ -363,7 +363,8 @@ class SamplingPipeline:
|
||||
)
|
||||
|
||||
|
||||
def get_guider_config(params: SamplingParams):
|
||||
def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
|
||||
guider_config: Dict[str, Any]
|
||||
if params.guider == Guider.IDENTITY:
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
@@ -389,7 +390,8 @@ def get_guider_config(params: SamplingParams):
|
||||
return guider_config
|
||||
|
||||
|
||||
def get_discretization_config(params: SamplingParams):
|
||||
def get_discretization_config(params: SamplingParams) -> Dict[str, Any]:
|
||||
discretization_config: Dict[str, Any]
|
||||
if params.discretization == Discretization.LEGACY_DDPM:
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
||||
|
||||
Reference in New Issue
Block a user