mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-23 14:44:31 +01:00
fix noisy latent handling
This commit is contained in:
@@ -226,15 +226,17 @@ class SamplingPipeline:
|
||||
negative_prompt: str = "",
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
stage2strength=None,
|
||||
noise_strength=None,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
if stage2strength is not None:
|
||||
sampler.discretization = Txt2NoisyDiscretizationWrapper(
|
||||
sampler.discretization,
|
||||
strength=stage2strength,
|
||||
original_steps=params.steps,
|
||||
)
|
||||
|
||||
sampler.discretization = self.wrap_discretization(
|
||||
sampler.discretization,
|
||||
image_strength=None,
|
||||
noise_strength=noise_strength,
|
||||
steps=params.steps,
|
||||
)
|
||||
|
||||
value_dict = asdict(params)
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
@@ -262,11 +264,15 @@ class SamplingPipeline:
|
||||
negative_prompt: str = "",
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
noise_strength=None,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
sampler.discretization = self.wrap_discretization(
|
||||
sampler.discretization, strength=params.img2img_strength
|
||||
sampler.discretization,
|
||||
image_strength=params.img2img_strength,
|
||||
noise_strength=noise_strength,
|
||||
steps=params.steps,
|
||||
)
|
||||
|
||||
height, width = image.shape[2], image.shape[3]
|
||||
@@ -286,12 +292,29 @@ class SamplingPipeline:
|
||||
filter=None,
|
||||
)
|
||||
|
||||
def wrap_discretization(self, discretization, strength=1.0):
|
||||
if (
|
||||
not isinstance(discretization, Img2ImgDiscretizationWrapper)
|
||||
and strength < 1.0
|
||||
def wrap_discretization(
|
||||
self, discretization, image_strength=None, noise_strength=None, steps=None
|
||||
):
|
||||
if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
|
||||
discretization, Txt2NoisyDiscretizationWrapper
|
||||
):
|
||||
return Img2ImgDiscretizationWrapper(discretization, strength=strength)
|
||||
return discretization # Already wrapped
|
||||
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
|
||||
discretization = Img2ImgDiscretizationWrapper(
|
||||
discretization, strength=image_strength
|
||||
)
|
||||
|
||||
if (
|
||||
noise_strength is not None
|
||||
and noise_strength < 1.0
|
||||
and noise_strength > 0.0
|
||||
and steps is not None
|
||||
):
|
||||
discretization = Txt2NoisyDiscretizationWrapper(
|
||||
discretization,
|
||||
strength=noise_strength,
|
||||
original_steps=steps,
|
||||
)
|
||||
return discretization
|
||||
|
||||
def refiner(
|
||||
@@ -300,7 +323,7 @@ class SamplingPipeline:
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
params: SamplingParams = SamplingParams(
|
||||
sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.2
|
||||
sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.15
|
||||
),
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
@@ -320,7 +343,7 @@ class SamplingPipeline:
|
||||
}
|
||||
|
||||
sampler.discretization = self.wrap_discretization(
|
||||
sampler.discretization, strength=params.img2img_strength
|
||||
sampler.discretization, image_strength=params.img2img_strength
|
||||
)
|
||||
|
||||
return do_img2img(
|
||||
@@ -332,6 +355,7 @@ class SamplingPipeline:
|
||||
skip_encode=True,
|
||||
return_latents=return_latents,
|
||||
filter=None,
|
||||
add_noise=False,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user