mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-23 14:44:31 +01:00
Add inference helpers & tests (#57)
* Add inference helpers & tests * Support testing with hatch * fixes to hatch script * add inference test action * change workflow trigger * widen trigger to test * revert changes to workflow triggers * Install local python in action * Trigger on push again * fix python version * add CODEOWNERS and change triggers * Report tests results * update action versions * format * Fix typo and add refiner helper * use a shared path loaded from a secret for checkpoints source * typo fix * Use device from input and remove duplicated code * PR feedback * fix call to load_model_from_config * Move model to gpu * Refactor helpers * cleanup * test refiner, prep for 1.0, align with metadata * fix paths on second load * deduplicate streamlit code * filenames * fixes * add pydantic to requirements * fix usage of `msg` in demo script * remove double text * run black * fix streamlit sampling when returning latents * extract function for streamlit output * another fix for streamlit outputs * fix img2img in streamlit * Make fp16 optional and fix device param * PR feedback * fix dict cast for dataclass * run black, update ci script * cache pip dependencies on hosted runners, remove extra runs * install package in ci env * fix cache path * PR cleanup * one more cleanup * don't cache, it filled up
This commit is contained in:
@@ -1,7 +1,14 @@
|
||||
import numpy as np
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from scripts.demo.streamlit_helpers import *
|
||||
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||
from sgm.inference.helpers import (
|
||||
do_img2img,
|
||||
do_sample,
|
||||
get_unique_embedder_keys_from_conditioner,
|
||||
perform_save_locally,
|
||||
)
|
||||
|
||||
SAVE_PATH = "outputs/demo/txt2img/"
|
||||
|
||||
@@ -131,6 +138,8 @@ def run_txt2img(
|
||||
|
||||
if st.button("Sample"):
|
||||
st.write(f"**Model I:** {version}")
|
||||
outputs = st.empty()
|
||||
st.text("Sampling")
|
||||
out = do_sample(
|
||||
state["model"],
|
||||
sampler,
|
||||
@@ -144,6 +153,8 @@ def run_txt2img(
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
)
|
||||
show_samples(out, outputs)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@@ -175,6 +186,8 @@ def run_img2img(
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
if st.button("Sample"):
|
||||
outputs = st.empty()
|
||||
st.text("Sampling")
|
||||
out = do_img2img(
|
||||
repeat(img, "1 ... -> n ...", n=num_samples),
|
||||
state["model"],
|
||||
@@ -185,6 +198,7 @@ def run_img2img(
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
)
|
||||
show_samples(out, outputs)
|
||||
return out
|
||||
|
||||
|
||||
@@ -249,8 +263,6 @@ if __name__ == "__main__":
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
||||
|
||||
state = init_st(version_dict)
|
||||
if state["msg"]:
|
||||
st.info(state["msg"])
|
||||
model = state["model"]
|
||||
|
||||
is_legacy = version_dict["is_legacy"]
|
||||
@@ -275,7 +287,6 @@ if __name__ == "__main__":
|
||||
|
||||
version_dict2 = VERSION2SPECS[version2]
|
||||
state2 = init_st(version_dict2)
|
||||
st.info(state2["msg"])
|
||||
|
||||
stage2strength = st.number_input(
|
||||
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
|
||||
@@ -315,6 +326,7 @@ if __name__ == "__main__":
|
||||
samples_z = None
|
||||
|
||||
if add_pipeline and samples_z is not None:
|
||||
outputs = st.empty()
|
||||
st.write("**Running Refinement Stage**")
|
||||
samples = apply_refiner(
|
||||
samples_z,
|
||||
@@ -325,6 +337,7 @@ if __name__ == "__main__":
|
||||
negative_prompt=negative_prompt if is_legacy else "",
|
||||
filter=filter,
|
||||
)
|
||||
show_samples(samples, outputs)
|
||||
|
||||
if save_locally and samples is not None:
|
||||
perform_save_locally(save_path, samples)
|
||||
|
||||
Reference in New Issue
Block a user