mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-04 21:34:26 +01:00
Extract method for img2img wrapper
This commit is contained in:
@@ -223,11 +223,10 @@ class SamplingPipeline:
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
if params.img2img_strength < 1.0:
|
||||
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||
sampler.discretization,
|
||||
strength=params.img2img_strength,
|
||||
)
|
||||
sampler.discretization = self.wrap_discretization(
|
||||
sampler.discretization, strength=params.img2img_strength
|
||||
)
|
||||
|
||||
height, width = image.shape[2], image.shape[3]
|
||||
value_dict = asdict(params)
|
||||
value_dict["prompt"] = prompt
|
||||
@@ -245,6 +244,14 @@ class SamplingPipeline:
|
||||
filter=None,
|
||||
)
|
||||
|
||||
def wrap_discretization(self, discretization, strength=1.0):
|
||||
if (
|
||||
not isinstance(discretization, Img2ImgDiscretizationWrapper)
|
||||
and strength < 1.0
|
||||
):
|
||||
return Img2ImgDiscretizationWrapper(discretization, strength=strength)
|
||||
return discretization
|
||||
|
||||
def refiner(
|
||||
self,
|
||||
image,
|
||||
@@ -270,11 +277,9 @@ class SamplingPipeline:
|
||||
"negative_aesthetic_score": 2.5,
|
||||
}
|
||||
|
||||
if params.img2img_strength < 1.0:
|
||||
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||
sampler.discretization,
|
||||
strength=params.img2img_strength,
|
||||
)
|
||||
sampler.discretization = self.wrap_discretization(
|
||||
sampler.discretization, strength=params.img2img_strength
|
||||
)
|
||||
|
||||
return do_img2img(
|
||||
image,
|
||||
|
||||
Reference in New Issue
Block a user