mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 12:24:27 +01:00
Easier default params
This commit is contained in:
@@ -132,7 +132,9 @@ def show_samples(samples, outputs):
|
||||
|
||||
def get_guider(params: SamplingParams, key=1) -> SamplingParams:
|
||||
params.guider = Guider(
|
||||
st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
|
||||
st.sidebar.selectbox(
|
||||
f"Discretization #{key}", [member.value for member in Guider]
|
||||
)
|
||||
)
|
||||
|
||||
if params.guider == Guider.VANILLA:
|
||||
@@ -161,10 +163,14 @@ def init_sampling(
|
||||
) -> Tuple[SamplingParams, int, int]:
|
||||
num_rows, num_cols = 1, 1
|
||||
if specify_num_samples:
|
||||
num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
|
||||
num_cols = st.number_input(
|
||||
f"num cols #{key}", value=2, min_value=1, max_value=10
|
||||
)
|
||||
|
||||
params.steps = int(
|
||||
st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
|
||||
st.sidebar.number_input(
|
||||
f"steps #{key}", value=params.steps, min_value=1, max_value=1000
|
||||
)
|
||||
)
|
||||
|
||||
params.sampler = Sampler(
|
||||
@@ -212,11 +218,15 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
|
||||
)
|
||||
|
||||
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
|
||||
params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
|
||||
params.s_noise = st.sidebar.number_input(
|
||||
"s_noise", value=params.s_noise, min_value=0.0
|
||||
)
|
||||
params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
|
||||
|
||||
elif params.sampler == Sampler.LINEAR_MULTISTEP:
|
||||
params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
|
||||
params.order = int(
|
||||
st.sidebar.number_input("order", value=params.order, min_value=1)
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
|
||||
@@ -61,9 +61,9 @@ class SamplingParams:
|
||||
Parameters for sampling.
|
||||
"""
|
||||
|
||||
width: int
|
||||
height: int
|
||||
steps: int
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
steps: Optional[int] = None
|
||||
sampler: Sampler = Sampler.EULER_EDM
|
||||
discretization: Discretization = Discretization.LEGACY_DDPM
|
||||
guider: Guider = Guider.VANILLA
|
||||
@@ -286,6 +286,15 @@ class SamplingPipeline:
|
||||
):
|
||||
if params is None:
|
||||
params = self.specs.default_params
|
||||
else:
|
||||
# Set defaults if optional params are not specified
|
||||
if params.width is None:
|
||||
params.width = self.specs.default_params.width
|
||||
if params.height is None:
|
||||
params.height = self.specs.default_params.height
|
||||
if params.steps is None:
|
||||
params.steps = self.specs.default_params.steps
|
||||
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
sampler.discretization = wrap_discretization(
|
||||
|
||||
Reference in New Issue
Block a user