# scripts/demo/sampling.py

import numpy as np
from pytorch_lightning import seed_everything
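# The star import below is expected to provide `st` (streamlit), `torch`, `os`,
# einops' `repeat`, and the init_* / show_samples / get_interactive_image /
# init_st UI helpers used throughout this script.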
from scripts.demo.streamlit_helpers import *
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering  # sic: "dectection" matches the upstream module filename
from sgm.inference.helpers import (
do_img2img,
do_sample,
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)

SAVE_PATH = "outputs/demo/txt2img/"

SD_XL_BASE_RATIOS = {
"0.5": (704, 1408),
"0.52": (704, 1344),
"0.57": (768, 1344),
"0.6": (768, 1280),
"0.68": (832, 1216),
"0.72": (832, 1152),
"0.78": (896, 1152),
"0.82": (896, 1088),
"0.88": (960, 1088),
"0.94": (960, 1024),
"1.0": (1024, 1024),
"1.07": (1024, 960),
"1.13": (1088, 960),
"1.21": (1088, 896),
"1.29": (1152, 896),
"1.38": (1152, 832),
"1.46": (1216, 832),
"1.67": (1280, 768),
"1.75": (1344, 768),
"1.91": (1344, 704),
"2.0": (1408, 704),
"2.09": (1472, 704),
"2.4": (1536, 640),
"2.5": (1600, 640),
"2.89": (1664, 576),
"3.0": (1728, 576),
}
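
# Each key above is the aspect ratio W/H rounded to two decimals; every side is
# a multiple of 64, and each bucket totals roughly 1024*1024 pixels, matching
# SDXL's multi-aspect training buckets.
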
VERSION2SPECS = {
"SD-XL base": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": False,
"config": "configs/inference/sd_xl_base.yaml",
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
"is_guided": True,
},
"sd-2.1": {
"H": 512,
"W": 512,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1.yaml",
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
"is_guided": True,
},
"sd-2.1-768": {
"H": 768,
"W": 768,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_2_1_768.yaml",
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
},
"SDXL-Refiner": {
"H": 1024,
"W": 1024,
"C": 4,
"f": 8,
"is_legacy": True,
"config": "configs/inference/sd_xl_refiner.yaml",
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
"is_guided": True,
},
}
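
# In the specs above: H/W are the default sample resolution, C the number of
# latent channels, and f the VAE downsampling factor. `is_legacy` enables the
# negative-prompt input (and skips zeroing the "txt" unconditional embedding),
# while `is_guided` selects CFG over the identity guider in init_sampling.
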
def load_img(display=True, key=None, device="cuda"):
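    """Load an input image interactively via Streamlit.

    Resizes the image down to the nearest multiple of 64 on each side, then
    returns it as a float32 tensor of shape [1, 3, H, W] scaled to [-1, 1] on
    `device`. Returns None if no image was provided.
    """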
image = get_interactive_image(key=key)
if image is None:
return None
if display:
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
return image.to(device)


def run_txt2img(
state, version, version_dict, is_legacy=False, return_latents=False, filter=None
):
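    """Run the interactive txt2img flow for the selected model.

    Builds the sidebar resolution controls (SDXL aspect-ratio buckets for the
    base model, free H/W otherwise), collects embedder options, and samples on
    button press. Note that `prompt` and `negative_prompt` are read from module
    scope, set in the __main__ block below. Returns the output of `do_sample`:
    images, or an (images, latents) tuple when return_latents=True.
    """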
if version == "SD-XL base":
ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
W, H = SD_XL_BASE_RATIOS[ratio]
else:
H = st.sidebar.number_input(
"H", value=version_dict["H"], min_value=64, max_value=2048
)
W = st.sidebar.number_input(
"W", value=version_dict["W"], min_value=64, max_value=2048
)
C = version_dict["C"]
F = version_dict["f"]
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
prompt=prompt,
negative_prompt=negative_prompt,
)
num_rows, num_cols, sampler = init_sampling(
use_identity_guider=not version_dict["is_guided"]
)
num_samples = num_rows * num_cols
if st.button("Sample"):
st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = do_sample(
state["model"],
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)
return out


def run_img2img(
state, version_dict, is_legacy=False, return_latents=False, filter=None
):
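    """Run the interactive img2img flow.

    Loads an input image, derives the target dimensions from its shape, asks
    for a denoising strength, and samples on button press. Returns the output
    of `do_img2img`, or None if no image was provided.
    """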
img = load_img()
if img is None:
return None
H, W = img.shape[2], img.shape[3]
init_dict = {
"orig_width": W,
"orig_height": H,
"target_width": W,
"target_height": H,
}
value_dict = init_embedder_options(
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
init_dict,
)
strength = st.number_input(
"**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
)
num_rows, num_cols, sampler = init_sampling(
img2img_strength=strength,
use_identity_guider=not version_dict["is_guided"],
)
num_samples = num_rows * num_cols
if st.button("Sample"):
outputs = st.empty()
st.text("Sampling")
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)
return out


def apply_refiner(
input,
state,
sampler,
num_samples,
prompt,
negative_prompt,
filter=None,
):
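    """Refine first-stage latents with the SDXL-Refiner model.

    `input` is a latent batch of shape [N, C, H/8, W/8], so spatial sizes are
    multiplied by 8 (the VAE downsampling factor f) to recover pixel-space
    dimensions for the conditioning. `skip_encode=True` because the input is
    already in latent space.
    """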
init_dict = {
"orig_width": input.shape[3] * 8,
"orig_height": input.shape[2] * 8,
"target_width": input.shape[3] * 8,
"target_height": input.shape[2] * 8,
}
value_dict = init_dict
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
value_dict["crop_coords_top"] = 0
value_dict["crop_coords_left"] = 0
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
st.warning(f"refiner input shape: {input.shape}")
samples = do_img2img(
input,
state["model"],
sampler,
value_dict,
num_samples,
skip_encode=True,
filter=filter,
)
return samples


if __name__ == "__main__":
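    # Streamlit app: pick a model version and mode, sample, optionally run the
    # SDXL refiner on returned latents, and save outputs locally.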
st.title("Stable Diffusion")
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
version_dict = VERSION2SPECS[version]
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
st.write("__________________________")
if version == "SD-XL base":
add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
st.write("__________________________")
else:
add_pipeline = False
filter = DeepFloydDataFiltering(verbose=False)
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
seed_everything(seed)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
state = init_st(version_dict)
model = state["model"]
is_legacy = version_dict["is_legacy"]
prompt = st.text_input(
"prompt",
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
)
if is_legacy:
negative_prompt = st.text_input("negative prompt", "")
else:
negative_prompt = "" # which is unused
if add_pipeline:
st.write("__________________________")
version2 = "SDXL-Refiner"
st.warning(
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
)
st.write("**Refiner Options:**")
version_dict2 = VERSION2SPECS[version2]
state2 = init_st(version_dict2)
stage2strength = st.number_input(
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
)
sampler2 = init_sampling(
key=2,
img2img_strength=stage2strength,
use_identity_guider=not version_dict2["is_guided"],
get_num_samples=False,
)
st.write("__________________________")
if mode == "txt2img":
out = run_txt2img(
state,
version,
version_dict,
is_legacy=is_legacy,
return_latents=add_pipeline,
filter=filter,
)
elif mode == "img2img":
out = run_img2img(
state,
version_dict,
is_legacy=is_legacy,
return_latents=add_pipeline,
filter=filter,
)
else:
raise ValueError(f"unknown mode {mode}")
if isinstance(out, (tuple, list)):
samples, samples_z = out
else:
samples = out
samples_z = None
if add_pipeline and samples_z is not None:
outputs = st.empty()
st.write("**Running Refinement Stage**")
samples = apply_refiner(
samples_z,
state2,
sampler2,
samples_z.shape[0],
prompt=prompt,
negative_prompt=negative_prompt if is_legacy else "",
filter=filter,
)
show_samples(samples, outputs)
if save_locally and samples is not None:
perform_save_locally(save_path, samples)