diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index a5f1b03..a0f3848 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -132,7 +132,9 @@ def show_samples(samples, outputs): def get_guider(params: SamplingParams, key=1) -> SamplingParams: params.guider = Guider( - st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider]) + st.sidebar.selectbox( + f"Discretization #{key}", [member.value for member in Guider] + ) ) if params.guider == Guider.VANILLA: @@ -161,10 +163,14 @@ def init_sampling( ) -> Tuple[SamplingParams, int, int]: num_rows, num_cols = 1, 1 if specify_num_samples: - num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10) + num_cols = st.number_input( + f"num cols #{key}", value=2, min_value=1, max_value=10 + ) params.steps = int( - st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000) + st.sidebar.number_input( + f"steps #{key}", value=params.steps, min_value=1, max_value=1000 + ) ) params.sampler = Sampler( @@ -212,11 +218,15 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams: ) elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL): - params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0) + params.s_noise = st.sidebar.number_input( + "s_noise", value=params.s_noise, min_value=0.0 + ) params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0) elif params.sampler == Sampler.LINEAR_MULTISTEP: - params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1)) + params.order = int( + st.sidebar.number_input("order", value=params.order, min_value=1) + ) return params diff --git a/sgm/inference/api.py b/sgm/inference/api.py index d863f5e..87592dc 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -61,9 +61,9 @@ class SamplingParams: Parameters for sampling. """ - width: int - height: int - steps: int + width: Optional[int] = None + height: Optional[int] = None + steps: Optional[int] = None sampler: Sampler = Sampler.EULER_EDM discretization: Discretization = Discretization.LEGACY_DDPM guider: Guider = Guider.VANILLA @@ -286,6 +286,15 @@ class SamplingPipeline: ): if params is None: params = self.specs.default_params + else: + # Set defaults if optional params are not specified + if params.width is None: + params.width = self.specs.default_params.width + if params.height is None: + params.height = self.specs.default_params.height + if params.steps is None: + params.steps = self.specs.default_params.steps + sampler = get_sampler_config(params) sampler.discretization = wrap_discretization(