mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-23 14:44:31 +01:00
Streamlit refactor (#105)
* initial streamlit refactoring pass * cleanup and fixes * fix refiner strength * Modify params correctly * fix exception
This commit is contained in:
@@ -1,8 +1,13 @@
|
||||
from dataclasses import asdict
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from sgm.inference.api import (
|
||||
SamplingParams,
|
||||
ModelArchitecture,
|
||||
SamplingPipeline,
|
||||
model_specs,
|
||||
)
|
||||
from sgm.inference.helpers import (
|
||||
do_img2img,
|
||||
do_sample,
|
||||
get_unique_embedder_keys_from_conditioner,
|
||||
perform_save_locally,
|
||||
)
|
||||
@@ -39,63 +44,6 @@ SD_XL_BASE_RATIOS = {
|
||||
"3.0": (1728, 576),
|
||||
}
|
||||
|
||||
VERSION2SPECS = {
|
||||
"SDXL-base-1.0": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": False,
|
||||
"config": "configs/inference/sd_xl_base.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_base_1.0.safetensors",
|
||||
},
|
||||
"SDXL-base-0.9": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": False,
|
||||
"config": "configs/inference/sd_xl_base.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
|
||||
},
|
||||
"SD-2.1": {
|
||||
"H": 512,
|
||||
"W": 512,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_2_1.yaml",
|
||||
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
|
||||
},
|
||||
"SD-2.1-768": {
|
||||
"H": 768,
|
||||
"W": 768,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_2_1_768.yaml",
|
||||
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
|
||||
},
|
||||
"SDXL-refiner-0.9": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_xl_refiner.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
|
||||
},
|
||||
"SDXL-refiner-1.0": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_xl_refiner.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_refiner_1.0.safetensors",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_img(display=True, key=None, device="cuda"):
|
||||
image = get_interactive_image(key=key)
|
||||
@@ -117,52 +65,48 @@ def load_img(display=True, key=None, device="cuda"):
|
||||
|
||||
def run_txt2img(
|
||||
state,
|
||||
version,
|
||||
version_dict,
|
||||
is_legacy=False,
|
||||
version: str,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
return_latents=False,
|
||||
filter=None,
|
||||
stage2strength=None,
|
||||
):
|
||||
if version.startswith("SDXL-base"):
|
||||
W, H = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
|
||||
spec: SamplingSpec = state.get("spec")
|
||||
model: SamplingPipeline = state.get("model")
|
||||
params: SamplingParams = state.get("params")
|
||||
if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
|
||||
params.width, params.height = st.selectbox(
|
||||
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
|
||||
)
|
||||
else:
|
||||
H = st.number_input("H", value=version_dict["H"], min_value=64, max_value=2048)
|
||||
W = st.number_input("W", value=version_dict["W"], min_value=64, max_value=2048)
|
||||
C = version_dict["C"]
|
||||
F = version_dict["f"]
|
||||
params.height = int(
|
||||
st.number_input("H", value=spec.height, min_value=64, max_value=2048)
|
||||
)
|
||||
params.width = int(
|
||||
st.number_input("W", value=spec.width, min_value=64, max_value=2048)
|
||||
)
|
||||
|
||||
init_dict = {
|
||||
"orig_width": W,
|
||||
"orig_height": H,
|
||||
"target_width": W,
|
||||
"target_height": H,
|
||||
}
|
||||
value_dict = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
||||
init_dict,
|
||||
params = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
sampler, num_rows, num_cols = init_sampling(stage2strength=stage2strength)
|
||||
params, num_rows, num_cols = init_sampling(params=params)
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
if st.button("Sample"):
|
||||
st.write(f"**Model I:** {version}")
|
||||
outputs = st.empty()
|
||||
st.text("Sampling")
|
||||
out = do_sample(
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
H,
|
||||
W,
|
||||
C,
|
||||
F,
|
||||
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
|
||||
out = model.text_to_image(
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
samples=int(num_samples),
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
noise_strength=stage2strength,
|
||||
filter=state.get("filter"),
|
||||
)
|
||||
|
||||
show_samples(out, outputs)
|
||||
@@ -172,51 +116,45 @@ def run_txt2img(
|
||||
|
||||
def run_img2img(
|
||||
state,
|
||||
version_dict,
|
||||
is_legacy=False,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
return_latents=False,
|
||||
filter=None,
|
||||
stage2strength=None,
|
||||
):
|
||||
model: SamplingPipeline = state.get("model")
|
||||
params: SamplingParams = state.get("params")
|
||||
|
||||
img = load_img()
|
||||
if img is None:
|
||||
return None
|
||||
H, W = img.shape[2], img.shape[3]
|
||||
params.height, params.width = img.shape[2], img.shape[3]
|
||||
|
||||
init_dict = {
|
||||
"orig_width": W,
|
||||
"orig_height": H,
|
||||
"target_width": W,
|
||||
"target_height": H,
|
||||
}
|
||||
value_dict = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
||||
init_dict,
|
||||
params = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
strength = st.number_input(
|
||||
params.img2img_strength = st.number_input(
|
||||
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
|
||||
)
|
||||
sampler, num_rows, num_cols = init_sampling(
|
||||
img2img_strength=strength,
|
||||
stage2strength=stage2strength,
|
||||
)
|
||||
params, num_rows, num_cols = init_sampling(params=params)
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
if st.button("Sample"):
|
||||
outputs = st.empty()
|
||||
st.text("Sampling")
|
||||
out = do_img2img(
|
||||
repeat(img, "1 ... -> n ...", n=num_samples),
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
|
||||
out = model.image_to_image(
|
||||
image=repeat(img, "1 ... -> n ...", n=num_samples),
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
samples=int(num_samples),
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
noise_strength=stage2strength,
|
||||
filter=state.get("filter"),
|
||||
)
|
||||
|
||||
show_samples(out, outputs)
|
||||
return out
|
||||
|
||||
@@ -224,39 +162,29 @@ def run_img2img(
|
||||
def apply_refiner(
|
||||
input,
|
||||
state,
|
||||
sampler,
|
||||
num_samples,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
filter=None,
|
||||
num_samples: int,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
finish_denoising=False,
|
||||
):
|
||||
init_dict = {
|
||||
"orig_width": input.shape[3] * 8,
|
||||
"orig_height": input.shape[2] * 8,
|
||||
"target_width": input.shape[3] * 8,
|
||||
"target_height": input.shape[2] * 8,
|
||||
}
|
||||
model: SamplingPipeline = state.get("model")
|
||||
params: SamplingParams = state.get("params")
|
||||
|
||||
value_dict = init_dict
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
|
||||
value_dict["crop_coords_top"] = 0
|
||||
value_dict["crop_coords_left"] = 0
|
||||
|
||||
value_dict["aesthetic_score"] = 6.0
|
||||
value_dict["negative_aesthetic_score"] = 2.5
|
||||
params.orig_width = input.shape[3] * 8
|
||||
params.orig_height = input.shape[2] * 8
|
||||
params.width = input.shape[3] * 8
|
||||
params.height = input.shape[2] * 8
|
||||
|
||||
st.warning(f"refiner input shape: {input.shape}")
|
||||
samples = do_img2img(
|
||||
input,
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
skip_encode=True,
|
||||
filter=filter,
|
||||
|
||||
samples = model.refiner(
|
||||
image=input,
|
||||
params=params,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
samples=num_samples,
|
||||
return_latents=False,
|
||||
filter=state.get("filter"),
|
||||
add_noise=not finish_denoising,
|
||||
)
|
||||
|
||||
@@ -265,28 +193,34 @@ def apply_refiner(
|
||||
|
||||
if __name__ == "__main__":
|
||||
st.title("Stable Diffusion")
|
||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||
version_dict = VERSION2SPECS[version]
|
||||
version = st.selectbox(
|
||||
"Model Version",
|
||||
[member.value for member in ModelArchitecture],
|
||||
0,
|
||||
)
|
||||
version_enum = ModelArchitecture(version)
|
||||
specs = model_specs[version_enum]
|
||||
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
||||
st.write("__________________________")
|
||||
|
||||
set_lowvram_mode(st.checkbox("Low vram mode", True))
|
||||
|
||||
if version.startswith("SDXL-base"):
|
||||
if str(version).startswith("stable-diffusion-xl"):
|
||||
add_pipeline = st.checkbox("Load SDXL-refiner?", False)
|
||||
st.write("__________________________")
|
||||
else:
|
||||
add_pipeline = False
|
||||
|
||||
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||
seed = int(
|
||||
st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||
)
|
||||
seed_everything(seed)
|
||||
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
||||
|
||||
state = init_st(version_dict, load_filter=True)
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
|
||||
state = init_st(model_specs[version_enum], load_filter=True)
|
||||
model = state["model"]
|
||||
|
||||
is_legacy = version_dict["is_legacy"]
|
||||
is_legacy = specs.is_legacy
|
||||
|
||||
prompt = st.text_input(
|
||||
"prompt",
|
||||
@@ -302,46 +236,59 @@ if __name__ == "__main__":
|
||||
|
||||
if add_pipeline:
|
||||
st.write("__________________________")
|
||||
version2 = st.selectbox("Refiner:", ["SDXL-refiner-1.0", "SDXL-refiner-0.9"])
|
||||
version2 = ModelArchitecture(
|
||||
st.selectbox(
|
||||
"Refiner:",
|
||||
[
|
||||
ModelArchitecture.SDXL_V1_REFINER.value,
|
||||
ModelArchitecture.SDXL_V0_9_REFINER.value,
|
||||
],
|
||||
)
|
||||
)
|
||||
st.warning(
|
||||
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
|
||||
)
|
||||
st.write("**Refiner Options:**")
|
||||
|
||||
version_dict2 = VERSION2SPECS[version2]
|
||||
state2 = init_st(version_dict2, load_filter=False)
|
||||
specs2 = model_specs[version2]
|
||||
state2 = init_st(specs2, load_filter=False)
|
||||
params2 = state2["params"]
|
||||
|
||||
stage2strength = st.number_input(
|
||||
params2.img2img_strength = st.number_input(
|
||||
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
|
||||
)
|
||||
|
||||
sampler2, *_ = init_sampling(
|
||||
params2, *_ = init_sampling(
|
||||
key=2,
|
||||
img2img_strength=stage2strength,
|
||||
params=state2["params"],
|
||||
specify_num_samples=False,
|
||||
)
|
||||
st.write("__________________________")
|
||||
finish_denoising = st.checkbox("Finish denoising with refiner.", True)
|
||||
if not finish_denoising:
|
||||
if finish_denoising:
|
||||
stage2strength = params2.img2img_strength
|
||||
else:
|
||||
stage2strength = None
|
||||
else:
|
||||
state2 = None
|
||||
params2 = None
|
||||
stage2strength = None
|
||||
|
||||
if mode == "txt2img":
|
||||
out = run_txt2img(
|
||||
state,
|
||||
version,
|
||||
version_dict,
|
||||
is_legacy=is_legacy,
|
||||
state=state,
|
||||
version=str(version),
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
return_latents=add_pipeline,
|
||||
filter=state.get("filter"),
|
||||
stage2strength=stage2strength,
|
||||
)
|
||||
elif mode == "img2img":
|
||||
out = run_img2img(
|
||||
state,
|
||||
version_dict,
|
||||
is_legacy=is_legacy,
|
||||
state=state,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
return_latents=add_pipeline,
|
||||
filter=state.get("filter"),
|
||||
stage2strength=stage2strength,
|
||||
)
|
||||
else:
|
||||
@@ -356,13 +303,11 @@ if __name__ == "__main__":
|
||||
outputs = st.empty()
|
||||
st.write("**Running Refinement Stage**")
|
||||
samples = apply_refiner(
|
||||
samples_z,
|
||||
state2,
|
||||
sampler2,
|
||||
samples_z.shape[0],
|
||||
input=samples_z,
|
||||
state=state2,
|
||||
num_samples=samples_z.shape[0],
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt if is_legacy else "",
|
||||
filter=state.get("filter"),
|
||||
finish_denoising=finish_denoising,
|
||||
)
|
||||
show_samples(samples, outputs)
|
||||
|
||||
Reference in New Issue
Block a user