From d245e2002fa6b2b0eb6826a954d738a6481c9505 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Wed, 9 Aug 2023 13:46:06 -0700 Subject: [PATCH] more types --- sgm/inference/api.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 668cc65..e3f3d17 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -18,6 +18,7 @@ from sgm.modules.diffusionmodules.sampling import ( LinearMultistepSampler, ) from sgm.util import load_model_from_config +import torch from typing import Optional, Dict, Any @@ -226,8 +227,8 @@ class SamplingPipeline: negative_prompt: str = "", samples: int = 1, return_latents: bool = False, - noise_strength=None, - filter=None, + noise_strength: Optional[float] = None, + filter: Any = None, ): sampler = get_sampler_config(params) @@ -260,13 +261,13 @@ class SamplingPipeline: def image_to_image( self, params: SamplingParams, - image, + image: torch.Tensor, prompt: str, negative_prompt: str = "", samples: int = 1, return_latents: bool = False, - noise_strength=None, - filter=None, + noise_strength: Optional[float] = None, + filter: Any = None, ): sampler = get_sampler_config(params) @@ -321,7 +322,7 @@ class SamplingPipeline: def refiner( self, - image, + image: torch.Tensor, prompt: str, negative_prompt: str = "", params: SamplingParams = SamplingParams( @@ -329,8 +330,8 @@ class SamplingPipeline: ), samples: int = 1, return_latents: bool = False, - filter=None, - add_noise=False, + filter: Any = None, + add_noise: bool = False, ): sampler = get_sampler_config(params) value_dict = {