From c0655731d5e637169f3019b349533528c02b6992 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Sat, 12 Aug 2023 04:25:56 -0700 Subject: [PATCH] fix streamlit inputs --- scripts/demo/sampling.py | 10 ++++++---- scripts/demo/streamlit_helpers.py | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index a18b21a..5415551 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -88,14 +88,14 @@ def run_txt2img( model: SamplingPipeline = state["model"] params: SamplingParams = state["params"] if version.startswith("stable-diffusion-xl") and version.endswith("-base"): - params.width, params.height = st.selectbox( + width, height = st.selectbox( "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10 ) else: - params.height = int( + height = int( st.number_input("H", value=spec.height, min_value=64, max_value=2048) ) - params.width = int( + width = int( st.number_input("W", value=spec.width, min_value=64, max_value=2048) ) @@ -107,6 +107,8 @@ def run_txt2img( ) params, num_rows, num_cols = init_sampling(params=params) num_samples = num_rows * num_cols + params.height = height + params.width = width if st.button("Sample"): st.write(f"**Model I:** {version}") @@ -289,8 +291,8 @@ if __name__ == "__main__": ) params2, *_ = init_sampling( - key=2, params=state2["params"], + key=2, specify_num_samples=False, ) st.write("__________________________") diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 93293dd..5c35069 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -130,7 +130,7 @@ def show_samples(samples, outputs): outputs.image(grid.cpu().numpy()) -def get_guider(key, params: SamplingParams) -> SamplingParams: +def get_guider(params: SamplingParams, key=1) -> SamplingParams: params.guider = Guider( st.sidebar.selectbox( f"Discretization #{key}", [member.value for member in Guider] @@ -157,8 +157,8 @@ def get_guider(key, params: SamplingParams) -> SamplingParams: def init_sampling( + params: SamplingParams, key=1, - params: SamplingParams = SamplingParams(), specify_num_samples=True, ) -> Tuple[SamplingParams, int, int]: params = SamplingParams(img2img_strength=params.img2img_strength)