more types

This commit is contained in:
Stephan Auerhahn
2023-08-09 13:46:06 -07:00
parent 725bea9f75
commit d245e2002f

View File

@@ -18,6 +18,7 @@ from sgm.modules.diffusionmodules.sampling import (
LinearMultistepSampler,
)
from sgm.util import load_model_from_config
import torch
from typing import Optional, Dict, Any
@@ -226,8 +227,8 @@ class SamplingPipeline:
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
noise_strength=None,
filter=None,
noise_strength: Optional[float] = None,
filter: Any = None,
):
sampler = get_sampler_config(params)
@@ -260,13 +261,13 @@ class SamplingPipeline:
def image_to_image(
self,
params: SamplingParams,
image,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
noise_strength=None,
filter=None,
noise_strength: Optional[float] = None,
filter: Any = None,
):
sampler = get_sampler_config(params)
@@ -321,7 +322,7 @@ class SamplingPipeline:
def refiner(
self,
image,
image: torch.Tensor,
prompt: str,
negative_prompt: str = "",
params: SamplingParams = SamplingParams(
@@ -329,8 +330,8 @@ class SamplingPipeline:
),
samples: int = 1,
return_latents: bool = False,
filter=None,
add_noise=False,
filter: Any = None,
add_noise: bool = False,
):
sampler = get_sampler_config(params)
value_dict = {