import os

import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from PIL import Image
from torchvision import transforms
from typing import Optional, Tuple, Dict, Any

from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.api import (
    Discretization,
    Guider,
    Sampler,
    SamplingParams,
    SamplingSpec,
    SamplingPipeline,
    Thresholder,
)
from sgm.inference.helpers import embed_watermark, CudaModelManager


@st.cache_resource()
def init_st(
    spec: SamplingSpec,
    load_ckpt=True,
    load_filter=True,
    use_fp16=True,
    enable_swap=True,
) -> Dict[str, Any]:
    state: Dict[str, Any] = dict()
    config = spec.config
    ckpt = spec.ckpt

    if enable_swap:
        pipeline = SamplingPipeline(
            model_spec=spec,
            use_fp16=use_fp16,
            device=CudaModelManager(device="cuda", swap_device="cpu"),
        )
    else:
        pipeline = SamplingPipeline(model_spec=spec, use_fp16=use_fp16)

    state["spec"] = spec
    state["model"] = pipeline
    state["ckpt"] = ckpt if load_ckpt else None
    state["config"] = config
    state["params"] = spec.default_params
    if load_filter:
        state["filter"] = DeepFloydDataFiltering(verbose=False)
    else:
        state["filter"] = None
    return state


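# Usage sketch (not part of the original helpers): init_st expects a SamplingSpec,
# which sgm.inference.api exposes through its model specification table. The names
# ModelArchitecture and model_specs below are assumptions based on that API and may
# need adjusting if the api module differs:
#
#   from sgm.inference.api import ModelArchitecture, model_specs
#
#   spec = model_specs[ModelArchitecture.SDXL_V1_BASE]
#   state = init_st(spec, load_filter=True, enable_swap=True)
#   pipeline = state["model"]   # SamplingPipeline
#   params = state["params"]    # default SamplingParams for this spec

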
def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def init_embedder_options(
    keys, params: SamplingParams, prompt=None, negative_prompt=None
) -> SamplingParams:
    for key in keys:
        if key == "txt":
            if prompt is None:
                prompt = st.text_input(
                    "Prompt", "A professional photograph of an astronaut riding a pig"
                )
            if negative_prompt is None:
                negative_prompt = st.text_input("Negative prompt", "")

        if key == "original_size_as_tuple":
            orig_width = st.number_input(
                "orig_width",
                value=params.orig_width,
                min_value=16,
            )
            orig_height = st.number_input(
                "orig_height",
                value=params.orig_height,
                min_value=16,
            )

            params.orig_width = int(orig_width)
            params.orig_height = int(orig_height)

        if key == "crop_coords_top_left":
            crop_coord_top = st.number_input(
                "crop_coords_top", value=params.crop_coords_top, min_value=0
            )
            crop_coord_left = st.number_input(
                "crop_coords_left", value=params.crop_coords_left, min_value=0
            )

            params.crop_coords_top = int(crop_coord_top)
            params.crop_coords_left = int(crop_coord_left)
    return params


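# Usage sketch (not in the original file): the keys passed in here usually come from
# the model's conditioner. The attribute path below (state["model"].model.conditioner)
# is an assumption about how SamplingPipeline wraps the underlying DiffusionEngine;
# adjust it if the pipeline exposes the conditioner differently:
#
#   keys = get_unique_embedder_keys_from_conditioner(state["model"].model.conditioner)
#   params = init_embedder_options(keys, state["params"])

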
def perform_save_locally(save_path, samples):
    os.makedirs(os.path.join(save_path), exist_ok=True)
    base_count = len(os.listdir(os.path.join(save_path)))
    samples = embed_watermark(samples)
    for sample in samples:
        sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        Image.fromarray(sample.astype(np.uint8)).save(
            os.path.join(save_path, f"{base_count:09}.png")
        )
        base_count += 1


def init_save_locally(_dir, init_value: bool = False):
    save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
    if save_locally:
        save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
    else:
        save_path = None

    return save_locally, save_path


def show_samples(samples, outputs):
    if isinstance(samples, tuple):
        samples, _ = samples
    grid = embed_watermark(torch.stack([samples]))
    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
    outputs.image(grid.cpu().numpy())


def get_guider(params: SamplingParams, key=1) -> SamplingParams:
    params.guider = Guider(
        st.sidebar.selectbox(
            f"Guider #{key}", [member.value for member in Guider]
        )
    )

    if params.guider == Guider.VANILLA:
        scale = st.number_input(
            f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
        )
        params.scale = scale
        thresholder = st.sidebar.selectbox(
            f"Thresholder #{key}",
            [
                "None",
            ],
        )

        if thresholder == "None":
            params.thresholder = Thresholder.NONE
        else:
            raise NotImplementedError
    return params


def init_sampling(
    params: SamplingParams,
    key=1,
    specify_num_samples=True,
) -> Tuple[SamplingParams, int, int]:
    params = SamplingParams(img2img_strength=params.img2img_strength)

    num_rows, num_cols = 1, 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
        )

    params.steps = int(
        st.sidebar.number_input(
            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
        )
    )

    params.sampler = Sampler(
        st.sidebar.selectbox(
            f"Sampler #{key}",
            [member.value for member in Sampler],
            0,
        )
    )
    params.discretization = Discretization(
        st.sidebar.selectbox(
            f"Discretization #{key}",
            [member.value for member in Discretization],
        )
    )

    params = get_discretization(params=params, key=key)
    params = get_guider(params=params, key=key)
    params = get_sampler(params=params, key=key)

    return params, num_rows, num_cols


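# Usage sketch (not in the original file): a text-to-image flow in the demo would wire
# these helpers together roughly as below. The text_to_image signature is an assumption
# based on sgm.inference.api.SamplingPipeline and may differ in detail:
#
#   params, num_rows, num_cols = init_sampling(state["params"])
#   out = state["model"].text_to_image(
#       params=params,
#       prompt="A professional photograph of an astronaut riding a pig",
#       negative_prompt="",
#       samples=num_rows * num_cols,
#   )
#   show_samples(out, st.empty())

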
def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
    if params.discretization == Discretization.EDM:
        params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
        params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
        params.rho = st.number_input(f"rho #{key}", value=params.rho)
    return params


def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
    if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
        params.s_churn = st.sidebar.number_input(
            f"s_churn #{key}", value=params.s_churn, min_value=0.0
        )
        params.s_tmin = st.sidebar.number_input(
            f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
        )
        params.s_tmax = st.sidebar.number_input(
            f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
        )
        params.s_noise = st.sidebar.number_input(
            f"s_noise #{key}", value=params.s_noise, min_value=0.0
        )

    elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
        params.s_noise = st.sidebar.number_input(
            f"s_noise #{key}", value=params.s_noise, min_value=0.0
        )
        params.eta = st.sidebar.number_input(
            f"eta #{key}", value=params.eta, min_value=0.0
        )

    elif params.sampler == Sampler.LINEAR_MULTISTEP:
        params.order = int(
            st.sidebar.number_input(f"order #{key}", value=params.order, min_value=1)
        )
    return params


def get_interactive_image(key=None) -> Optional[Image.Image]:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if image is not None:
        image = Image.open(image)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        return image
    return None


def load_img(display=True, key=None) -> Optional[torch.Tensor]:
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 2.0 - 1.0),
        ]
    )
    img = transform(image)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img


def get_init_img(batch_size=1, key=None):
    init_image = load_img(key=key).cuda()
    init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
    return init_image
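

# Usage sketch (not in the original file): an image-to-image flow would typically pair
# get_init_img with the pipeline's img2img entry point. The image_to_image signature
# below is an assumption based on sgm.inference.api.SamplingPipeline; note that
# get_init_img assumes an image has already been uploaded (load_img returns None
# otherwise):
#
#   init_image = get_init_img(batch_size=num_rows * num_cols)
#   out = state["model"].image_to_image(
#       params=params,
#       image=init_image,
#       prompt=prompt,
#       negative_prompt=negative_prompt,
#       samples=num_rows * num_cols,
#   )
#   show_samples(out, st.empty())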