align with streamlit helpers and re-de-deuplicate

This commit is contained in:
Stephan Auerhahn
2023-08-06 11:20:22 +00:00
parent 77d0e27747
commit b216934b7e
4 changed files with 140 additions and 468 deletions

View File

@@ -1,5 +1,11 @@
from pytorch_lightning import seed_everything
from sgm.inference.helpers import (
do_img2img,
do_sample,
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)
from scripts.demo.streamlit_helpers import *
SAVE_PATH = "outputs/demo/txt2img/"
@@ -99,9 +105,7 @@ def load_img(display=True, key=None, device="cuda"):
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
@@ -143,6 +147,8 @@ def run_txt2img(
if st.button("Sample"):
st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = do_sample(
state["model"],
sampler,
@@ -156,6 +162,9 @@ def run_txt2img(
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)
return out
@@ -184,9 +193,7 @@ def run_img2img(
prompt=prompt,
negative_prompt=negative_prompt,
)
strength = st.number_input(
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
)
strength = st.number_input("**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0)
sampler, num_rows, num_cols = init_sampling(
img2img_strength=strength,
stage2strength=stage2strength,
@@ -194,6 +201,8 @@ def run_img2img(
num_samples = num_rows * num_cols
if st.button("Sample"):
outputs = st.empty()
st.text("Sampling")
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
@@ -204,6 +213,7 @@ def run_img2img(
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)
return out
@@ -342,6 +352,7 @@ if __name__ == "__main__":
samples_z = None
if add_pipeline and samples_z is not None:
outputs = st.empty()
st.write("**Running Refinement Stage**")
samples = apply_refiner(
samples_z,
@@ -353,6 +364,7 @@ if __name__ == "__main__":
filter=state.get("filter"),
finish_denoising=finish_denoising,
)
show_samples(samples, outputs)
if save_locally and samples is not None:
perform_save_locally(save_path, samples)