From 73287ec3a30109020e57996612399c99ca759b39 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 3 Aug 2023 23:42:11 +0000 Subject: [PATCH] Extract method for img2img wrapper --- sgm/inference/api.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index be4b245..ec17dfe 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -223,11 +223,10 @@ class SamplingPipeline: ): sampler = get_sampler_config(params) - if params.img2img_strength < 1.0: - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, - strength=params.img2img_strength, - ) + sampler.discretization = self.wrap_discretization( + sampler.discretization, strength=params.img2img_strength + ) + height, width = image.shape[2], image.shape[3] value_dict = asdict(params) value_dict["prompt"] = prompt @@ -245,6 +244,14 @@ class SamplingPipeline: filter=None, ) + def wrap_discretization(self, discretization, strength=1.0): + if ( + not isinstance(discretization, Img2ImgDiscretizationWrapper) + and strength < 1.0 + ): + return Img2ImgDiscretizationWrapper(discretization, strength=strength) + return discretization + def refiner( self, image, @@ -270,11 +277,9 @@ class SamplingPipeline: "negative_aesthetic_score": 2.5, } - if params.img2img_strength < 1.0: - sampler.discretization = Img2ImgDiscretizationWrapper( - sampler.discretization, - strength=params.img2img_strength, - ) + sampler.discretization = self.wrap_discretization( + sampler.discretization, strength=params.img2img_strength + ) return do_img2img( image,