From 853adb402252d2162adffe806381e9636d93d8e5 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 3 Aug 2023 12:50:23 -0700 Subject: [PATCH] Add defaults to refiner function --- sgm/inference/api.py | 8 +++++--- tests/inference/test_inference.py | 4 ++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 64c1b02..be4b245 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -55,7 +55,7 @@ class Thresholder(str, Enum): class SamplingParams: width: int = 1024 height: int = 1024 - steps: int = 50 + steps: int = 40 sampler: Sampler = Sampler.DPMPP2M discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA @@ -247,10 +247,12 @@ class SamplingPipeline: def refiner( self, - params: SamplingParams, image, prompt: str, - negative_prompt: Optional[str] = None, + negative_prompt: str = "", + params: SamplingParams = SamplingParams( + sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.2 + ), samples: int = 1, return_latents: bool = False, ): diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index 2b2af11..ae6f355 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -102,10 +102,10 @@ class TestInference: samples, samples_z = output assert samples is not None assert samples_z is not None - refiner_pipeline.refiner( - params=SamplingParams(sampler=sampler_enum.value, steps=10), + refiner_pipeline.refiner( image=samples_z, prompt="A professional photograph of an astronaut riding a pig", + params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.20), negative_prompt="", samples=1, )