mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
Add defaults to refiner function
This commit is contained in:
@@ -55,7 +55,7 @@ class Thresholder(str, Enum):
|
||||
class SamplingParams:
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
steps: int = 50
|
||||
steps: int = 40
|
||||
sampler: Sampler = Sampler.DPMPP2M
|
||||
discretization: Discretization = Discretization.LEGACY_DDPM
|
||||
guider: Guider = Guider.VANILLA
|
||||
@@ -247,10 +247,12 @@ class SamplingPipeline:
|
||||
|
||||
def refiner(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
image,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt: str = "",
|
||||
params: SamplingParams = SamplingParams(
|
||||
sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.2
|
||||
),
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
):
|
||||
|
||||
@@ -102,10 +102,10 @@ class TestInference:
|
||||
samples, samples_z = output
|
||||
assert samples is not None
|
||||
assert samples_z is not None
|
||||
refiner_pipeline.refiner(
|
||||
params=SamplingParams(sampler=sampler_enum.value, steps=10),
|
||||
refiner_pipeline.refiner(
|
||||
image=samples_z,
|
||||
prompt="A professional photograph of an astronaut riding a pig",
|
||||
params=SamplingParams(sampler=sampler_enum.value, steps=40, img2img_strength=0.20),
|
||||
negative_prompt="",
|
||||
samples=1,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user