diff --git a/sgm/inference/api.py b/sgm/inference/api.py index 77f5667..1ab892f 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -226,15 +226,17 @@ class SamplingPipeline: negative_prompt: str = "", samples: int = 1, return_latents: bool = False, - stage2strength=None, + noise_strength=None, ): sampler = get_sampler_config(params) - if stage2strength is not None: - sampler.discretization = Txt2NoisyDiscretizationWrapper( - sampler.discretization, - strength=stage2strength, - original_steps=params.steps, - ) + + sampler.discretization = self.wrap_discretization( + sampler.discretization, + image_strength=None, + noise_strength=noise_strength, + steps=params.steps, + ) + value_dict = asdict(params) value_dict["prompt"] = prompt value_dict["negative_prompt"] = negative_prompt @@ -262,11 +264,15 @@ class SamplingPipeline: negative_prompt: str = "", samples: int = 1, return_latents: bool = False, + noise_strength=None, ): sampler = get_sampler_config(params) sampler.discretization = self.wrap_discretization( - sampler.discretization, strength=params.img2img_strength + sampler.discretization, + image_strength=params.img2img_strength, + noise_strength=noise_strength, + steps=params.steps, ) height, width = image.shape[2], image.shape[3] @@ -286,12 +292,29 @@ class SamplingPipeline: filter=None, ) - def wrap_discretization(self, discretization, strength=1.0): - if ( - not isinstance(discretization, Img2ImgDiscretizationWrapper) - and strength < 1.0 + def wrap_discretization( + self, discretization, image_strength=None, noise_strength=None, steps=None + ): + if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance( + discretization, Txt2NoisyDiscretizationWrapper ): - return Img2ImgDiscretizationWrapper(discretization, strength=strength) + return discretization # Already wrapped + if image_strength is not None and image_strength < 1.0 and image_strength > 0.0: + discretization = Img2ImgDiscretizationWrapper( + discretization, strength=image_strength + ) + + if ( + noise_strength is not None + and noise_strength < 1.0 + and noise_strength > 0.0 + and steps is not None + ): + discretization = Txt2NoisyDiscretizationWrapper( + discretization, + strength=noise_strength, + original_steps=steps, + ) return discretization def refiner( @@ -300,7 +323,7 @@ class SamplingPipeline: prompt: str, negative_prompt: str = "", params: SamplingParams = SamplingParams( - sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.2 + sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.15 ), samples: int = 1, return_latents: bool = False, @@ -320,7 +343,7 @@ class SamplingPipeline: } sampler.discretization = self.wrap_discretization( - sampler.discretization, strength=params.img2img_strength + sampler.discretization, image_strength=params.img2img_strength ) return do_img2img( @@ -332,6 +355,7 @@ class SamplingPipeline: skip_encode=True, return_latents=return_latents, filter=None, + add_noise=False, ) diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py index ae6f355..617e408 100644 --- a/tests/inference/test_inference.py +++ b/tests/inference/test_inference.py @@ -68,9 +68,7 @@ class TestInference: assert output is not None @pytest.mark.parametrize("sampler_enum", Sampler) - @pytest.mark.parametrize( - "use_init_image", [True, False], ids=["img2img", "txt2img"] - ) + @pytest.mark.parametrize("use_init_image", [True, False], ids=["img2img", "txt2img"]) def test_sdxl_with_refiner( self, sdxl_pipelines: Tuple[SamplingPipeline, SamplingPipeline], @@ -81,13 +79,12 @@ class TestInference: if use_init_image: output = base_pipeline.image_to_image( params=SamplingParams(sampler=sampler_enum.value, steps=10), - image=self.create_init_image( - base_pipeline.specs.height, base_pipeline.specs.width - ), + image=self.create_init_image(base_pipeline.specs.height, base_pipeline.specs.width), prompt="A professional photograph of an astronaut riding a pig", negative_prompt="", samples=1, return_latents=True, + noise_strength=0.15, ) else: output = base_pipeline.text_to_image( @@ -96,16 +93,17 @@ class TestInference: negative_prompt="", samples=1, return_latents=True, + noise_strength=0.15, ) assert isinstance(output, (tuple, list)) samples, samples_z = output assert samples is not None assert samples_z is not None - refiner_pipeline.refiner( + refiner_pipeline.refiner( image=samples_z, prompt="A professional photograph of an astronaut riding a pig", - params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.20), + params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.15), negative_prompt="", samples=1, )