mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-23 14:44:31 +01:00
align with streamlit helpers and re-de-deuplicate
This commit is contained in:
@@ -1,5 +1,11 @@
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from sgm.inference.helpers import (
|
||||
do_img2img,
|
||||
do_sample,
|
||||
get_unique_embedder_keys_from_conditioner,
|
||||
perform_save_locally,
|
||||
)
|
||||
from scripts.demo.streamlit_helpers import *
|
||||
|
||||
SAVE_PATH = "outputs/demo/txt2img/"
|
||||
@@ -99,9 +105,7 @@ def load_img(display=True, key=None, device="cuda"):
|
||||
st.image(image)
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
width, height = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
image = image.resize((width, height))
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
@@ -143,6 +147,8 @@ def run_txt2img(
|
||||
|
||||
if st.button("Sample"):
|
||||
st.write(f"**Model I:** {version}")
|
||||
outputs = st.empty()
|
||||
st.text("Sampling")
|
||||
out = do_sample(
|
||||
state["model"],
|
||||
sampler,
|
||||
@@ -156,6 +162,9 @@ def run_txt2img(
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
show_samples(out, outputs)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -184,9 +193,7 @@ def run_img2img(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
strength = st.number_input(
|
||||
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
|
||||
)
|
||||
strength = st.number_input("**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0)
|
||||
sampler, num_rows, num_cols = init_sampling(
|
||||
img2img_strength=strength,
|
||||
stage2strength=stage2strength,
|
||||
@@ -194,6 +201,8 @@ def run_img2img(
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
if st.button("Sample"):
|
||||
outputs = st.empty()
|
||||
st.text("Sampling")
|
||||
out = do_img2img(
|
||||
repeat(img, "1 ... -> n ...", n=num_samples),
|
||||
state["model"],
|
||||
@@ -204,6 +213,7 @@ def run_img2img(
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
)
|
||||
show_samples(out, outputs)
|
||||
return out
|
||||
|
||||
|
||||
@@ -342,6 +352,7 @@ if __name__ == "__main__":
|
||||
samples_z = None
|
||||
|
||||
if add_pipeline and samples_z is not None:
|
||||
outputs = st.empty()
|
||||
st.write("**Running Refinement Stage**")
|
||||
samples = apply_refiner(
|
||||
samples_z,
|
||||
@@ -353,6 +364,7 @@ if __name__ == "__main__":
|
||||
filter=state.get("filter"),
|
||||
finish_denoising=finish_denoising,
|
||||
)
|
||||
show_samples(samples, outputs)
|
||||
|
||||
if save_locally and samples is not None:
|
||||
perform_save_locally(save_path, samples)
|
||||
|
||||
Reference in New Issue
Block a user