mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-18 20:24:26 +01:00
fix streamlit inputs
This commit is contained in:
@@ -88,14 +88,14 @@ def run_txt2img(
|
||||
model: SamplingPipeline = state["model"]
|
||||
params: SamplingParams = state["params"]
|
||||
if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
|
||||
params.width, params.height = st.selectbox(
|
||||
width, height = st.selectbox(
|
||||
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
|
||||
)
|
||||
else:
|
||||
params.height = int(
|
||||
height = int(
|
||||
st.number_input("H", value=spec.height, min_value=64, max_value=2048)
|
||||
)
|
||||
params.width = int(
|
||||
width = int(
|
||||
st.number_input("W", value=spec.width, min_value=64, max_value=2048)
|
||||
)
|
||||
|
||||
@@ -107,6 +107,8 @@ def run_txt2img(
|
||||
)
|
||||
params, num_rows, num_cols = init_sampling(params=params)
|
||||
num_samples = num_rows * num_cols
|
||||
params.height = height
|
||||
params.width = width
|
||||
|
||||
if st.button("Sample"):
|
||||
st.write(f"**Model I:** {version}")
|
||||
@@ -289,8 +291,8 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
params2, *_ = init_sampling(
|
||||
key=2,
|
||||
params=state2["params"],
|
||||
key=2,
|
||||
specify_num_samples=False,
|
||||
)
|
||||
st.write("__________________________")
|
||||
|
||||
@@ -130,7 +130,7 @@ def show_samples(samples, outputs):
|
||||
outputs.image(grid.cpu().numpy())
|
||||
|
||||
|
||||
def get_guider(key, params: SamplingParams) -> SamplingParams:
|
||||
def get_guider(params: SamplingParams, key=1) -> SamplingParams:
|
||||
params.guider = Guider(
|
||||
st.sidebar.selectbox(
|
||||
f"Discretization #{key}", [member.value for member in Guider]
|
||||
@@ -157,8 +157,8 @@ def get_guider(key, params: SamplingParams) -> SamplingParams:
|
||||
|
||||
|
||||
def init_sampling(
|
||||
params: SamplingParams,
|
||||
key=1,
|
||||
params: SamplingParams = SamplingParams(),
|
||||
specify_num_samples=True,
|
||||
) -> Tuple[SamplingParams, int, int]:
|
||||
params = SamplingParams(img2img_strength=params.img2img_strength)
|
||||
|
||||
Reference in New Issue
Block a user