Add inference helpers & tests (#57)

* Add inference helpers & tests

* Support testing with hatch

* fixes to hatch script

* add inference test action

* change workflow trigger

* widen trigger to test

* revert changes to workflow triggers

* Install local python in action

* Trigger on push again

* fix python version

* add CODEOWNERS and change triggers

* Report test results

* update action versions

* format

* Fix typo and add refiner helper

* use a shared path loaded from a secret for checkpoints source

* typo fix

* Use device from input and remove duplicated code

* PR feedback

* fix call to load_model_from_config

* Move model to gpu

* Refactor helpers

* cleanup

* test refiner, prep for 1.0, align with metadata

* fix paths on second load

* deduplicate streamlit code

* filenames

* fixes

* add pydantic to requirements

* fix usage of `msg` in demo script

* remove double text

* run black

* fix streamlit sampling when returning latents

* extract function for streamlit output

* another fix for streamlit outputs

* fix img2img in streamlit

* Make fp16 optional and fix device param

* PR feedback

* fix dict cast for dataclass

* run black, update ci script

* cache pip dependencies on hosted runners, remove extra runs

* install package in ci env

* fix cache path

* PR cleanup

* one more cleanup

* don't cache, it filled up
Stephan Auerhahn
2023-07-26 04:37:24 -07:00
committed by GitHub
parent e596332148
commit 931d7a389a
11 changed files with 889 additions and 346 deletions
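
The diff below shows how the new sgm.inference.helpers entry points replace inline sampling code in the Streamlit demo. The PR also adds inference tests (run via hatch in CI); as a rough illustration only, a smoke test for the helpers might be shaped like this, where the fixtures and do_sample's arguments are assumptions, not the PR's actual test code:

# Hypothetical inference-test sketch. Only the helper name comes from this
# PR's imports; the fixtures (model, sampler, value_dict) and the argument
# values are illustrative assumptions.
import pytest

from sgm.inference.helpers import do_sample


@pytest.mark.inference
def test_txt2img_smoke(model, sampler, value_dict):
    samples = do_sample(
        model, sampler, value_dict,
        num_samples=1, H=1024, W=1024, C=4, F=8,
        return_latents=False, filter=None,
    )
    assert samples is not None and samples.shape[0] == 1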

@@ -1,7 +1,14 @@
 import numpy as np
 from pytorch_lightning import seed_everything
 from scripts.demo.streamlit_helpers import *
 from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
+from sgm.inference.helpers import (
+    do_img2img,
+    do_sample,
+    get_unique_embedder_keys_from_conditioner,
+    perform_save_locally,
+)
 SAVE_PATH = "outputs/demo/txt2img/"
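
For orientation, a minimal sketch of how the newly imported helpers compose outside of Streamlit; everything beyond the helper names and the keyword names visible at the demo's call sites (return_latents, filter) is an assumption:

# Sketch only: assumes `model` and `sampler` are already constructed and on
# the right device, and that `value_dict` carries the conditioning values
# expected by the model's embedders.
from sgm.inference.helpers import do_sample, perform_save_locally

samples = do_sample(
    model,
    sampler,
    value_dict,
    num_samples=2,
    H=1024, W=1024, C=4, F=8,  # output size, latent channels, downscale factor
    return_latents=False,
    filter=nsfw_filter,        # e.g. a DeepFloydDataFiltering instance
)
perform_save_locally("outputs/demo/txt2img/", samples)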
@@ -131,6 +138,8 @@ def run_txt2img(
     if st.button("Sample"):
         st.write(f"**Model I:** {version}")
+        outputs = st.empty()
+        st.text("Sampling")
         out = do_sample(
             state["model"],
             sampler,
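
The two added lines set up the progress display: st.empty() reserves a slot in page order before sampling begins, so images rendered into it later appear above the "Sampling" status text. A self-contained demonstration of this Streamlit pattern:

# Standalone demo of the placeholder pattern used throughout this diff:
# reserve a slot first, do the slow work, then fill the slot afterwards.
import time
import streamlit as st

outputs = st.empty()        # reserves an output slot at this point in the page
st.text("Sampling")         # status line renders below the reserved slot
time.sleep(2)               # stand-in for the actual diffusion sampling
with outputs.container():   # fill the earlier slot; content lands above the status line
    st.write("**Done** - images would render here")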
@@ -144,6 +153,8 @@
             return_latents=return_latents,
             filter=filter,
         )
+        show_samples(out, outputs)
+
         return out
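
show_samples is the shared output routine factored out by this PR ("extract function for streamlit output" above); its body is not part of this diff. A plausible sketch, consistent with the demo returning (images, latents) tuples when return_latents=True:

# Plausible show_samples sketch - the real body is not shown in this diff;
# the tuple handling mirrors the "returning latents" fix in the commit log.
import streamlit as st
from einops import rearrange


def show_samples(samples, outputs):
    if isinstance(samples, tuple):  # (decoded images, latents) when return_latents=True
        samples = samples[0]
    with outputs.container():
        for sample in samples:
            sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
            st.image(sample.astype("uint8"))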
@@ -175,6 +186,8 @@ def run_img2img(
     num_samples = num_rows * num_cols
     if st.button("Sample"):
+        outputs = st.empty()
+        st.text("Sampling")
         out = do_img2img(
             repeat(img, "1 ... -> n ...", n=num_samples),
             state["model"],
@@ -185,6 +198,7 @@
             return_latents=return_latents,
             filter=filter,
         )
+        show_samples(out, outputs)
         return out
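
The repeat(img, "1 ... -> n ...", n=num_samples) call above is einops broadcasting: the single input image is copied across a new leading batch axis, so one uploaded image yields num_samples img2img results. For example:

# einops.repeat: "1 ... -> n ..." turns one image into an n-image batch,
# leaving all remaining dimensions untouched.
import torch
from einops import repeat

img = torch.randn(1, 3, 1024, 1024)         # one image, batch size 1
batch = repeat(img, "1 ... -> n ...", n=4)  # shape: (4, 3, 1024, 1024)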
@@ -249,8 +263,6 @@ if __name__ == "__main__":
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
     state = init_st(version_dict)
-    if state["msg"]:
-        st.info(state["msg"])
     model = state["model"]
     is_legacy = version_dict["is_legacy"]
@@ -275,7 +287,6 @@
         version_dict2 = VERSION2SPECS[version2]
         state2 = init_st(version_dict2)
-        st.info(state2["msg"])
         stage2strength = st.number_input(
             "**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
@@ -315,6 +326,7 @@
         samples_z = None
     if add_pipeline and samples_z is not None:
+        outputs = st.empty()
         st.write("**Running Refinement Stage**")
         samples = apply_refiner(
             samples_z,
@@ -325,6 +337,7 @@
             negative_prompt=negative_prompt if is_legacy else "",
             filter=filter,
         )
+        show_samples(samples, outputs)
         if save_locally and samples is not None:
             perform_save_locally(save_path, samples)
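
Taken together, the last four hunks wire the refinement stage into the same output pattern as the base stage: the base pass returns latents (samples_z), the refiner consumes them, and results land in a dedicated st.empty() placeholder. Schematically, with ... marking argument lists this diff does not show:

# Schematic only - init_st, run_txt2img, and apply_refiner signatures are
# abbreviated; the structure mirrors the hunks above.
state = init_st(version_dict)            # the redundant st.info(state["msg"]) is gone
samples, samples_z = run_txt2img(...)    # latents come back when return_latents=True

if add_pipeline and samples_z is not None:
    outputs = st.empty()                 # placeholder for the refined images
    st.write("**Running Refinement Stage**")
    samples = apply_refiner(samples_z, ...)  # refine base-stage latents
    show_samples(samples, outputs)

if save_locally and samples is not None:
    perform_save_locally(save_path, samples)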