Easier default params

This commit is contained in:
Stephan Auerhahn
2023-08-12 13:22:04 -07:00
parent e32972b85b
commit 2fc4680bf9
2 changed files with 27 additions and 8 deletions

View File

@@ -132,7 +132,9 @@ def show_samples(samples, outputs):
def get_guider(params: SamplingParams, key=1) -> SamplingParams:
params.guider = Guider(
st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
st.sidebar.selectbox(
f"Discretization #{key}", [member.value for member in Guider]
)
)
if params.guider == Guider.VANILLA:
@@ -161,10 +163,14 @@ def init_sampling(
) -> Tuple[SamplingParams, int, int]:
num_rows, num_cols = 1, 1
if specify_num_samples:
num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
)
params.steps = int(
st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
st.sidebar.number_input(
f"steps #{key}", value=params.steps, min_value=1, max_value=1000
)
)
params.sampler = Sampler(
@@ -212,11 +218,15 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
)
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
params.s_noise = st.sidebar.number_input(
"s_noise", value=params.s_noise, min_value=0.0
)
params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
elif params.sampler == Sampler.LINEAR_MULTISTEP:
params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
params.order = int(
st.sidebar.number_input("order", value=params.order, min_value=1)
)
return params

View File

@@ -61,9 +61,9 @@ class SamplingParams:
Parameters for sampling.
"""
width: int
height: int
steps: int
width: Optional[int] = None
height: Optional[int] = None
steps: Optional[int] = None
sampler: Sampler = Sampler.EULER_EDM
discretization: Discretization = Discretization.LEGACY_DDPM
guider: Guider = Guider.VANILLA
@@ -286,6 +286,15 @@ class SamplingPipeline:
):
if params is None:
params = self.specs.default_params
else:
# Set defaults if optional params are not specified
if params.width is None:
params.width = self.specs.default_params.width
if params.height is None:
params.height = self.specs.default_params.height
if params.steps is None:
params.steps = self.specs.default_params.steps
sampler = get_sampler_config(params)
sampler.discretization = wrap_discretization(