# Files
# generative-models/scripts/demo/streamlit_helpers.py
# Stephan Auerhahn a009aa8a9f adding some typing
# 2023-08-09 13:27:30 -07:00
#
# 274 lines
# 8.0 KiB
# Python

import os
import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from PIL import Image
from torchvision import transforms
from typing import Optional, Tuple, Dict, Any
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.api import (
Discretization,
Guider,
Sampler,
SamplingParams,
SamplingSpec,
SamplingPipeline,
Thresholder,
)
from sgm.inference.helpers import (
embed_watermark,
)
@st.cache_resource()
def _init_st_cached(
    spec: SamplingSpec, load_ckpt: bool, load_filter: bool, use_lowvram: bool
) -> Dict[str, Any]:
    """Build the state dict for ``spec``; cached per distinct argument tuple.

    ``use_lowvram`` is an explicit argument (rather than a read of the module
    global) precisely so that it participates in the ``st.cache_resource``
    key: changing the low-VRAM setting yields a freshly built pipeline.
    """
    config = spec.config
    ckpt = spec.ckpt
    pipeline = SamplingPipeline(
        model_spec=spec,
        use_fp16=use_lowvram,
        device="cpu" if use_lowvram else "cuda",
    )
    state: Dict[str, Any] = {
        "spec": spec,
        "model": pipeline,
        # Only this bookkeeping entry depends on load_ckpt; the pipeline above
        # is constructed the same way either way.
        "ckpt": ckpt if load_ckpt else None,
        "config": config,
        "params": SamplingParams(),
    }
    if load_filter:
        state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state


def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
    """Return the (cached) demo state for a model spec.

    Args:
        spec: Model specification handed to ``SamplingPipeline``.
        load_ckpt: When False, ``state["ckpt"]`` is set to None.
        load_filter: When True, a ``DeepFloydDataFiltering`` instance is
            stored under ``state["filter"]``.

    Returns:
        Dict with keys ``"spec"``, ``"model"`` (the ``SamplingPipeline``),
        ``"ckpt"``, ``"config"``, ``"params"`` and optionally ``"filter"``.
        The pipeline uses fp16 on the CPU when low-VRAM mode is on, and the
        CUDA device otherwise.
    """
    # Pass the current low-VRAM flag explicitly so it becomes part of the
    # cache key; reading the global inside the cached function would keep
    # returning a pipeline built under the old setting.
    return _init_st_cached(spec, load_ckpt, load_filter, lowvram_mode)
# Module-wide switch consulted by init_st: when truthy the pipeline is built
# with use_fp16 enabled and placed on the CPU instead of CUDA.
lowvram_mode = False


def set_lowvram_mode(mode):
    """Overwrite the module-level ``lowvram_mode`` flag with ``mode``."""
    global lowvram_mode
    lowvram_mode = mode
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    Keys are de-duplicated in first-seen order (via ``dict.fromkeys``) rather
    than through a ``set``. The returned order, and therefore the order in
    which a consumer iterates over it, is then deterministic and does not
    depend on the string hash seed of the process.
    """
    return list(dict.fromkeys(embedder.input_key for embedder in conditioner.embedders))
def init_embedder_options(
    keys, params: SamplingParams, prompt=None, negative_prompt=None
) -> SamplingParams:
    """Render the input widgets required by each conditioner key.

    Handled keys:
      * ``"txt"``: shows "Prompt" / "Negative prompt" text boxes for
        whichever of ``prompt`` / ``negative_prompt`` was not supplied.
      * ``"original_size_as_tuple"``: edits ``params.orig_width`` and
        ``params.orig_height`` (minimum 16).
      * ``"crop_coords_top_left"``: edits ``params.crop_coords_top`` and
        ``params.crop_coords_left`` (minimum 0).
    Any other key is ignored. ``params`` is mutated in place and returned.

    NOTE(review): the prompt / negative-prompt values read here are neither
    stored on ``params`` nor returned, so they do not affect the result.
    Confirm that callers obtain the prompt some other way.
    """
    for name in keys:
        if name == "txt":
            if prompt is None:
                prompt = st.text_input(
                    "Prompt", "A professional photograph of an astronaut riding a pig"
                )
            if negative_prompt is None:
                negative_prompt = st.text_input("Negative prompt", "")
        elif name == "original_size_as_tuple":
            width = st.number_input("orig_width", value=params.orig_width, min_value=16)
            height = st.number_input("orig_height", value=params.orig_height, min_value=16)
            params.orig_width, params.orig_height = int(width), int(height)
        elif name == "crop_coords_top_left":
            top = st.number_input(
                "crop_coords_top", value=params.crop_coords_top, min_value=0
            )
            left = st.number_input(
                "crop_coords_left", value=params.crop_coords_left, min_value=0
            )
            params.crop_coords_top, params.crop_coords_left = int(top), int(left)
    return params
def perform_save_locally(save_path, samples):
    """Watermark ``samples`` and save each image as a PNG inside ``save_path``.

    The directory is created if needed. Files are named with a zero-padded,
    nine-digit counter (e.g. ``000000042.png``) that starts at the number of
    entries already in the directory. Any counter value whose file already
    exists is skipped, so earlier outputs are never overwritten.

    Args:
        save_path: Target directory.
        samples: Batch passed to ``embed_watermark``. Iterating the result
            must yield per-image arrays laid out ``c h w`` (as required by the
            ``rearrange`` below). Values are multiplied by 255 before the uint8
            cast, so they are presumably in [0, 1]; confirm upstream.
    """
    os.makedirs(save_path, exist_ok=True)
    index = len(os.listdir(save_path))
    for sample in embed_watermark(samples):
        target = os.path.join(save_path, f"{index:09}.png")
        # The entry count alone can collide with an existing file (e.g. after
        # deletions, or when unrelated files live here), so never clobber one.
        while os.path.exists(target):
            index += 1
            target = os.path.join(save_path, f"{index:09}.png")
        pixels = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        Image.fromarray(pixels.astype(np.uint8)).save(target)
        index += 1
def init_save_locally(_dir, init_value: bool = False):
    """Ask whether generated images should be written to disk, and where.

    Shows a "Save images locally" sidebar checkbox, pre-set to
    ``init_value``. When it is ticked, a "Save path" text box appears,
    defaulting to ``<_dir>/samples``.

    Returns:
        ``(save_locally, save_path)``; ``save_path`` is None when saving is off.
    """
    wants_save = st.sidebar.checkbox("Save images locally", value=init_value)
    chosen_path = (
        st.text_input("Save path", value=os.path.join(_dir, "samples"))
        if wants_save
        else None
    )
    return wants_save, chosen_path
def show_samples(samples, outputs):
    """Watermark a batch of images and display it in ``outputs`` as one row.

    ``samples`` is either the image batch itself or a 2-tuple whose first
    element is the batch (the second element is ignored). The batch is
    wrapped in a leading grid axis of size 1 and rearranged so the images sit
    side by side (``(n h) (b w) c``) before being handed to ``outputs.image``.
    """
    if isinstance(samples, tuple):
        images, _ignored = samples
    else:
        images = samples
    watermarked = embed_watermark(torch.stack([images]))
    tiled = rearrange(watermarked, "n b c h w -> (n h) (b w) c")
    outputs.image(tiled.cpu().numpy())
def get_guider(key, params: SamplingParams) -> SamplingParams:
    """Let the user pick the guider (and its settings) for stage ``key``.

    Sets ``params.guider`` from a sidebar selectbox. For ``Guider.VANILLA`` it
    also sets ``params.scale`` from a "cfg-scale" input in [0, 100]. The
    thresholder is always ``Thresholder.NONE``, since "None" is the only
    option offered.

    Raises:
        NotImplementedError: if a thresholder other than "None" is selected.
    """
    # Labelled "Guider": this selector must not share the
    # "Discretization #{key}" label used by the discretization selectbox.
    params.guider = Guider(
        st.sidebar.selectbox(f"Guider #{key}", [member.value for member in Guider])
    )
    if params.guider == Guider.VANILLA:
        params.scale = st.number_input(
            f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
        )

    thresholder = st.sidebar.selectbox(f"Thresholder #{key}", ["None"])
    if thresholder != "None":
        raise NotImplementedError
    params.thresholder = Thresholder.NONE
    return params
def init_sampling(
    key=1,
    params: Optional[SamplingParams] = None,
    specify_num_samples=True,
) -> Tuple[SamplingParams, int, int]:
    """Render the sampling-settings widgets for stage ``key`` and collect them.

    Only ``img2img_strength`` is carried over from the incoming ``params``.
    Every other field starts from the ``SamplingParams`` defaults and is then
    filled in from widgets: steps, sampler, discretization, and the
    discretization-, guider- and sampler-specific options.

    Args:
        key: Appended to widget labels as ``#{key}``.
        params: Optional source of ``img2img_strength``. ``None`` (the default)
            means the ``SamplingParams`` default is used.
        specify_num_samples: When True, show a "num cols" input (1-10,
            default 2); otherwise a single column is used.

    Returns:
        ``(params, num_rows, num_cols)``; ``num_rows`` is always 1.
    """
    # Build a fresh SamplingParams per call instead of reading from a shared
    # default instance that is evaluated once at import time.
    if params is None:
        params = SamplingParams()
    else:
        params = SamplingParams(img2img_strength=params.img2img_strength)

    num_rows, num_cols = 1, 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
        )

    params.steps = int(
        st.sidebar.number_input(
            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
        )
    )
    params.sampler = Sampler(
        st.sidebar.selectbox(
            f"Sampler #{key}",
            [member.value for member in Sampler],
            0,
        )
    )
    params.discretization = Discretization(
        st.sidebar.selectbox(
            f"Discretization #{key}",
            [member.value for member in Discretization],
        )
    )

    params = get_discretization(params, key=key)
    params = get_guider(key=key, params=params)
    params = get_sampler(params, key=key)
    return params, num_rows, num_cols
def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
    """Expose the EDM discretization knobs when EDM is the chosen discretization.

    For ``Discretization.EDM`` the user can edit ``sigma_min``, ``sigma_max``
    and ``rho`` (labels suffixed with ``#{key}``). For any other
    discretization ``params`` is returned untouched.
    """
    if params.discretization != Discretization.EDM:
        return params
    params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
    params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
    params.rho = st.number_input(f"rho #{key}", value=params.rho)
    return params
def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
    """Show the sidebar options specific to the selected sampler.

    * EDM Euler / Heun: ``s_churn``, ``s_tmin``, ``s_tmax``, ``s_noise``.
    * Ancestral Euler / DPM++ 2S: ``s_noise``, ``eta``.
    * Linear multistep: ``order`` (integer >= 1).
    Other samplers leave ``params`` unchanged.

    Every label carries the ``#{key}`` suffix. Without it, two sampling stages
    using the same sampler would create identical widgets, which Streamlit
    rejects as duplicates.
    """
    if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
        params.s_churn = st.sidebar.number_input(
            f"s_churn #{key}", value=params.s_churn, min_value=0.0
        )
        params.s_tmin = st.sidebar.number_input(
            f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
        )
        params.s_tmax = st.sidebar.number_input(
            f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
        )
        params.s_noise = st.sidebar.number_input(
            f"s_noise #{key}", value=params.s_noise, min_value=0.0
        )
    elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
        params.s_noise = st.sidebar.number_input(
            f"s_noise #{key}", value=params.s_noise, min_value=0.0
        )
        params.eta = st.sidebar.number_input(
            f"eta #{key}", value=params.eta, min_value=0.0
        )
    elif params.sampler == Sampler.LINEAR_MULTISTEP:
        params.order = int(
            st.sidebar.number_input(f"order #{key}", value=params.order, min_value=1)
        )
    return params
def get_interactive_image(key=None) -> Optional[Image.Image]:
    """Return the user's uploaded image as an RGB PIL image, or None if absent.

    Accepted upload types are jpg/JPEG/png. ``key`` is forwarded to
    ``st.file_uploader``. Non-RGB images are converted to RGB.
    """
    upload = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if upload is None:
        return None
    picture = Image.open(upload)
    return picture if picture.mode == "RGB" else picture.convert("RGB")
def load_img(display=True, key=None) -> Optional[torch.Tensor]:
    """Fetch the uploaded image and turn it into a batched tensor in [-1, 1].

    Returns None when nothing has been uploaded. Otherwise the image is
    optionally displayed, and its size is printed. It is then converted with
    ``ToTensor``, which yields [0, 1] for 8-bit images, and remapped with
    ``x * 2 - 1``. Finally it gets a leading batch axis of size 1, and its
    min/max/mean are written to the page.
    """
    pil_image = get_interactive_image(key=key)
    if pil_image is None:
        return None
    if display:
        st.image(pil_image)

    width, height = pil_image.size
    print(f"loaded input image of size ({width}, {height})")

    scaled = transforms.ToTensor()(pil_image) * 2.0 - 1.0
    batched = scaled.unsqueeze(0)
    st.text(
        f"input min/max/mean: {batched.min():.3f}/{batched.max():.3f}/{batched.mean():.3f}"
    )
    return batched
def get_init_img(batch_size=1, key=None) -> Optional[torch.Tensor]:
    """Load the uploaded image onto CUDA and repeat it ``batch_size`` times.

    Returns:
        None when no image has been uploaded, matching ``load_img``'s Optional
        return. Otherwise the image tensor, moved with ``.cuda()`` and repeated
        along its leading (size-1) axis to length ``batch_size``.
    """
    init_image = load_img(key=key)
    if init_image is None:
        # Calling .cuda() on the missing image would raise AttributeError;
        # report the absence to the caller instead.
        return None
    init_image = init_image.cuda()
    return repeat(init_image, "1 ... -> b ...", b=batch_size)