Files
generative-models/scripts/demo/streamlit_helpers.py
2023-08-06 11:30:40 +00:00

374 lines
12 KiB
Python

import os
import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler,
)
from sgm.inference.helpers import (
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
embed_watermark,
)
from sgm.util import load_model_from_config
@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
    """Build the demo state dictionary for one model version.

    The function is wrapped in ``st.cache_resource``, so reuse of the result
    across reruns is left to Streamlit.

    Args:
        version_dict: Mapping providing ``"config"`` (passed to
            ``OmegaConf.load``) and ``"ckpt"`` (passed on to
            ``load_model_from_config``). Both keys are always read.
        load_ckpt: When false, ``None`` is handed to
            ``load_model_from_config`` instead of the checkpoint and ``None``
            is stored under ``"ckpt"``.
        load_filter: When true, a ``DeepFloydDataFiltering`` instance is
            stored under ``"filter"``.

    Returns:
        A dict with the keys ``"model"``, ``"ckpt"``, ``"config"`` and, if
        requested, ``"filter"``.
    """
    # The previous version guarded this work with ``if not "model" in state``
    # right after creating an empty ``state`` dict; that test could never be
    # false, so the body now runs unconditionally.
    config_source = version_dict["config"]
    ckpt_source = version_dict["ckpt"]
    ckpt = ckpt_source if load_ckpt else None

    config = OmegaConf.load(config_source)
    model = load_model_from_config(config, ckpt, freeze=False)

    state = {"model": model, "ckpt": ckpt, "config": config}
    if load_filter:
        state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state
# Module-wide switch; read by initial_model_load, written by set_lowvram_mode.
lowvram_mode = False


def set_lowvram_mode(mode):
    """Store ``mode`` in the module-level ``lowvram_mode`` flag."""
    global lowvram_mode
    lowvram_mode = mode
def initial_model_load(model):
    """Prepare ``model`` according to the module-level ``lowvram_mode`` flag.

    When the flag is truthy only ``model.model.half()`` is called; otherwise
    ``model.cuda()`` is called.

    Returns:
        The ``model`` object that was passed in.
    """
    if not lowvram_mode:
        model.cuda()
        return model
    model.model.half()
    return model
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    Keys come back in first-seen order. De-duplicating through a ``set``
    produced an order that depends on string hashing and may therefore differ
    between interpreter runs; anything built by iterating over the result
    would then change order from one launch to the next.

    Args:
        conditioner: Object exposing an iterable ``embedders`` attribute whose
            items each have an ``input_key`` attribute.

    Returns:
        A list containing every distinct ``input_key`` exactly once.
    """
    # dict keys are unique and keep insertion order, giving a stable result.
    return list(dict.fromkeys(x.input_key for x in conditioner.embedders))
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Collect the conditioning values needed for the given embedder keys.

    Args:
        keys: Iterable of embedder input keys; unrecognised keys are ignored.
        init_dict: Source of ``"orig_width"``/``"orig_height"`` defaults and
            of ``"target_width"``/``"target_height"``; each is read only when
            the corresponding key is present in ``keys``.
        prompt: Prompt text; when ``None`` and ``"txt"`` is requested, it is
            asked for through a Streamlit text input.
        negative_prompt: Negative prompt text, handled like ``prompt``.

    Returns:
        A dict with the values gathered for the requested keys.
    """
    # Hardcoded demo settings; might undergo some changes in the future
    collected = {}
    for name in keys:
        if name == "txt":
            # Only prompt the user for whatever the caller did not supply.
            if prompt is None:
                prompt = st.text_input(
                    "Prompt", "A professional photograph of an astronaut riding a pig"
                )
            if negative_prompt is None:
                negative_prompt = st.text_input("Negative prompt", "")
            collected["prompt"] = prompt
            collected["negative_prompt"] = negative_prompt
        elif name == "original_size_as_tuple":
            width_input = st.number_input(
                "orig_width",
                value=init_dict["orig_width"],
                min_value=16,
            )
            height_input = st.number_input(
                "orig_height",
                value=init_dict["orig_height"],
                min_value=16,
            )
            collected["orig_width"] = width_input
            collected["orig_height"] = height_input
        elif name == "crop_coords_top_left":
            top_input = st.number_input("crop_coords_top", value=0, min_value=0)
            left_input = st.number_input("crop_coords_left", value=0, min_value=0)
            collected["crop_coords_top"] = top_input
            collected["crop_coords_left"] = left_input
        elif name == "aesthetic_score":
            collected["aesthetic_score"] = 6.0
            collected["negative_aesthetic_score"] = 2.5
        elif name == "target_size_as_tuple":
            collected["target_width"] = init_dict["target_width"]
            collected["target_height"] = init_dict["target_height"]
    return collected
def perform_save_locally(save_path, samples):
    """Watermark ``samples`` and write each one into ``save_path`` as a PNG.

    The directory is created when missing. File names are zero-padded
    nine-digit counters that start at the number of entries already present
    in the directory.

    Args:
        save_path: Destination directory.
        samples: Iterable of tensors laid out as ``c h w`` after
            ``embed_watermark``; each is multiplied by 255 and cast to
            ``uint8`` (presumably values lie in [0, 1] -- TODO confirm).
    """
    os.makedirs(save_path, exist_ok=True)
    first_index = len(os.listdir(save_path))
    for offset, sample in enumerate(embed_watermark(samples)):
        pixels = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        target = os.path.join(save_path, f"{first_index + offset:09}.png")
        Image.fromarray(pixels.astype(np.uint8)).save(target)
def init_save_locally(_dir, init_value: bool = False):
    """Ask whether images should be saved locally and, if so, where.

    Args:
        _dir: Base directory; the suggested save path is ``<_dir>/samples``.
        init_value: Initial state of the sidebar checkbox.

    Returns:
        ``(save_locally, save_path)`` where ``save_path`` is ``None`` when
        the checkbox is not ticked.
    """
    enabled = st.sidebar.checkbox("Save images locally", value=init_value)
    if not enabled:
        return enabled, None
    chosen_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
    return enabled, chosen_path
def show_samples(samples, outputs):
    """Render ``samples`` side by side through ``outputs.image``.

    Args:
        samples: A tensor, or a 2-tuple whose first element is that tensor
            (the second element is ignored).
        outputs: Object exposing an ``image`` method that receives the
            resulting array.
    """
    if isinstance(samples, tuple):
        first, _ = samples
        samples = first
    # Add a leading axis so the batch can be tiled horizontally below.
    stacked = torch.stack([samples])
    marked = embed_watermark(stacked)
    tiled = rearrange(marked, "n b c h w -> (n h) (b w) c")
    outputs.image(tiled.cpu().numpy())
def get_guider(key):
    """Build a guider config from the Streamlit controls for panel ``key``.

    The selectbox is labelled ``Guider #{key}``. It used to be labelled
    ``Discretization #{key}``, the same text as the discretization selectbox
    created in ``init_sampling``, which left two identically named controls
    in the sidebar.

    Args:
        key: Suffix appended to every widget label.

    Returns:
        A dict with a ``"target"`` entry and, for ``VanillaCFG``, a
        ``"params"`` entry holding ``scale`` and ``dyn_thresh_config``.

    Raises:
        NotImplementedError: For a guider or thresholder choice that has no
            config defined here.
    """
    guider = st.sidebar.selectbox(
        f"Guider #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
        ],
    )

    if guider == "IdentityGuider":
        return {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}

    if guider != "VanillaCFG":
        raise NotImplementedError

    scale = st.number_input(
        f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
    )
    thresholder = st.sidebar.selectbox(
        f"Thresholder #{key}",
        [
            "None",
        ],
    )
    if thresholder != "None":
        raise NotImplementedError
    dyn_thresh_config = {
        "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
    }

    return {
        "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
        "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
    }
def init_sampling(
    key=1,
    img2img_strength=1.0,
    specify_num_samples=True,
    stage2strength=None,
):
    """Create a sampler from the Streamlit controls of panel ``key``.

    Args:
        key: Suffix appended to the widget labels of this panel.
        img2img_strength: When below 1.0 the sampler's discretization is
            wrapped in ``Img2ImgDiscretizationWrapper`` with this strength.
        specify_num_samples: When true, the number of columns is asked for;
            otherwise it stays 1.
        stage2strength: When not ``None`` the discretization is wrapped in
            ``Txt2NoisyDiscretizationWrapper`` with this strength.

    Returns:
        ``(sampler, num_rows, num_cols)``; ``num_rows`` is always 1.
    """
    num_rows = 1
    num_cols = 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
        )

    steps = st.sidebar.number_input(
        f"steps #{key}", value=40, min_value=1, max_value=1000
    )
    sampler_choices = [
        "EulerEDMSampler",
        "HeunEDMSampler",
        "EulerAncestralSampler",
        "DPMPP2SAncestralSampler",
        "DPMPP2MSampler",
        "LinearMultistepSampler",
    ]
    sampler_name = st.sidebar.selectbox(f"Sampler #{key}", sampler_choices, 0)
    discretization_choices = [
        "LegacyDDPMDiscretization",
        "EDMDiscretization",
    ]
    discretization_name = st.sidebar.selectbox(
        f"Discretization #{key}", discretization_choices
    )

    # Widgets are created in this order: discretization, guider, sampler.
    discretization_config = get_discretization(discretization_name, key=key)
    guider_config = get_guider(key=key)
    sampler = get_sampler(
        sampler_name, steps, discretization_config, guider_config, key=key
    )

    if img2img_strength < 1.0:
        st.warning(
            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
        )
        sampler.discretization = Img2ImgDiscretizationWrapper(
            sampler.discretization, strength=img2img_strength
        )
    if stage2strength is not None:
        # Applied after (and therefore around) any img2img wrapper set above.
        sampler.discretization = Txt2NoisyDiscretizationWrapper(
            sampler.discretization, strength=stage2strength, original_steps=steps
        )
    return sampler, num_rows, num_cols
def get_discretization(discretization, key=1):
    """Return the config dict for the discretization called ``discretization``.

    Args:
        discretization: ``"LegacyDDPMDiscretization"`` or
            ``"EDMDiscretization"``.
        key: Suffix appended to the labels of the EDM number inputs.

    Returns:
        A dict with a ``"target"`` entry and, for the EDM variant, a
        ``"params"`` entry with ``sigma_min``, ``sigma_max`` and ``rho`` read
        from Streamlit number inputs.

    Raises:
        ValueError: If ``discretization`` is not one of the supported names.
            Previously this case fell through to the ``return`` with the
            result variable never assigned.
    """
    if discretization == "LegacyDDPMDiscretization":
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    if discretization == "EDMDiscretization":
        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
        rho = st.number_input(f"rho #{key}", value=3.0)
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": sigma_min,
                "sigma_max": sigma_max,
                "rho": rho,
            },
        }
    raise ValueError(f"unknown discretization {discretization}!")
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
    """Instantiate the sampler called ``sampler_name`` from sidebar settings.

    Sampler-specific parameters are read from sidebar number inputs. Every
    label carries the ``#{key}`` suffix. The ``s_noise``/``eta`` inputs of the
    ancestral samplers and the ``order`` input of the linear multistep sampler
    previously lacked it (unlike the EDM inputs), so two panels with different
    ``key`` values selecting the same sampler family produced identical
    widgets.

    Args:
        sampler_name: One of ``"EulerEDMSampler"``, ``"HeunEDMSampler"``,
            ``"EulerAncestralSampler"``, ``"DPMPP2SAncestralSampler"``,
            ``"DPMPP2MSampler"`` or ``"LinearMultistepSampler"``.
        steps: Forwarded as ``num_steps``.
        discretization_config: Forwarded unchanged to the sampler.
        guider_config: Forwarded unchanged to the sampler.
        key: Suffix appended to the widget labels.

    Returns:
        The constructed sampler instance (always built with ``verbose=True``).

    Raises:
        ValueError: If ``sampler_name`` is not one of the names above.
    """
    # Keyword arguments every sampler constructor receives.
    shared = {
        "num_steps": steps,
        "discretization_config": discretization_config,
        "guider_config": guider_config,
        "verbose": True,
    }

    if sampler_name in ("EulerEDMSampler", "HeunEDMSampler"):
        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
        if sampler_name == "EulerEDMSampler":
            sampler_cls = EulerEDMSampler
        else:
            sampler_cls = HeunEDMSampler
        return sampler_cls(
            s_churn=s_churn,
            s_tmin=s_tmin,
            s_tmax=s_tmax,
            s_noise=s_noise,
            **shared,
        )

    if sampler_name in ("EulerAncestralSampler", "DPMPP2SAncestralSampler"):
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
        eta = st.sidebar.number_input(f"eta #{key}", value=1.0, min_value=0.0)
        if sampler_name == "EulerAncestralSampler":
            sampler_cls = EulerAncestralSampler
        else:
            sampler_cls = DPMPP2SAncestralSampler
        return sampler_cls(eta=eta, s_noise=s_noise, **shared)

    if sampler_name == "DPMPP2MSampler":
        return DPMPP2MSampler(**shared)

    if sampler_name == "LinearMultistepSampler":
        order = st.sidebar.number_input(f"order #{key}", value=4, min_value=1)
        return LinearMultistepSampler(order=order, **shared)

    raise ValueError(f"unknown sampler {sampler_name}!")
def get_interactive_image(key=None) -> Image.Image:
    """Let the user upload an image and return it in RGB mode.

    Args:
        key: Passed to ``st.file_uploader`` as the widget key.

    Returns:
        The uploaded picture opened with PIL and converted to ``"RGB"`` when
        it has another mode, or ``None`` while nothing has been uploaded.
    """
    uploaded = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if uploaded is None:
        return None
    picture = Image.open(uploaded)
    if picture.mode != "RGB":
        picture = picture.convert("RGB")
    return picture
def load_img(display=True, key=None):
    """Fetch the uploaded image and convert it to a batched tensor.

    Args:
        display: When true, the picture is shown with ``st.image``.
        key: Widget key forwarded to ``get_interactive_image``.

    Returns:
        ``None`` when no image is available; otherwise the result of
        ``ToTensor`` followed by ``x * 2.0 - 1.0`` with a leading axis of
        size 1 added (presumably values end up in [-1, 1] -- TODO confirm).
    """
    picture = get_interactive_image(key=key)
    if picture is None:
        return None
    if display:
        st.image(picture)

    w, h = picture.size
    print(f"loaded input image of size ({w}, {h})")

    to_tensor = transforms.ToTensor()
    rescale = transforms.Lambda(lambda x: x * 2.0 - 1.0)
    pipeline = transforms.Compose([to_tensor, rescale])

    img = pipeline(picture)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img
def get_init_img(batch_size=1, key=None):
    """Load the uploaded image, move it to CUDA and repeat it along the batch.

    Args:
        batch_size: Number of copies along the leading axis.
        key: Widget key forwarded to ``load_img``.

    Returns:
        The repeated tensor, or ``None`` when ``load_img`` returns ``None``
        (nothing uploaded). Previously that case failed with an
        ``AttributeError`` because ``.cuda()`` was called on ``None``.
    """
    init_image = load_img(key=key)
    if init_image is None:
        return None
    init_image = init_image.cuda()
    return repeat(init_image, "1 ... -> b ...", b=batch_size)