diff --git a/sgm/inference/api.py b/sgm/inference/api.py index be4b245..ec17dfe 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -223,11 +223,10 @@ class SamplingPipeline: ): sampler = get_sampler_config(params) - if params.img2img_strength < 1.0: - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, - strength=params.img2img_strength, - ) + sampler.discretization = self.wrap_discretization( + sampler.discretization, strength=params.img2img_strength + ) + height, width = image.shape[2], image.shape[3] value_dict = asdict(params) value_dict["prompt"] = prompt @@ -245,6 +244,14 @@ class SamplingPipeline: filter=None, ) + def wrap_discretization(self, discretization, strength=1.0): + if ( + not isinstance(discretization, Img2ImgDiscretizationWrapper) + and strength < 1.0 + ): + return Img2ImgDiscretizationWrapper(discretization, strength=strength) + return discretization + def refiner( self, image, @@ -270,11 +277,9 @@ class SamplingPipeline: "negative_aesthetic_score": 2.5, } - if params.img2img_strength < 1.0: - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, - strength=params.img2img_strength, - ) + sampler.discretization = self.wrap_discretization( + sampler.discretization, strength=params.img2img_strength + ) return do_img2img( image,