mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-01-04 14:14:24 +01:00
soon is now
This commit is contained in:
157
scripts/demo/detect.py
Normal file
157
scripts/demo/detect.py
Normal file
@@ -0,0 +1,157 @@
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from imwatermark import WatermarkDecoder
|
||||
except ImportError as e:
|
||||
try:
|
||||
# Assume some of the other dependencies such as torch are not fulfilled
|
||||
# import file without loading unnecessary libraries.
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
spec = importlib.util.find_spec("imwatermark.maxDct")
|
||||
assert spec is not None
|
||||
maxDct = importlib.util.module_from_spec(spec)
|
||||
sys.modules["maxDct"] = maxDct
|
||||
spec.loader.exec_module(maxDct)
|
||||
|
||||
class WatermarkDecoder(object):
|
||||
"""A minimal version of
|
||||
https://github.com/ShieldMnt/invisible-watermark/blob/main/imwatermark/watermark.py
|
||||
to only reconstruct bits using dwtDct"""
|
||||
|
||||
def __init__(self, wm_type="bytes", length=0):
|
||||
assert wm_type == "bits", "Only bits defined in minimal import"
|
||||
self._wmType = wm_type
|
||||
self._wmLen = length
|
||||
|
||||
def reconstruct(self, bits):
|
||||
if len(bits) != self._wmLen:
|
||||
raise RuntimeError("bits are not matched with watermark length")
|
||||
|
||||
return bits
|
||||
|
||||
def decode(self, cv2Image, method="dwtDct", **configs):
|
||||
(r, c, channels) = cv2Image.shape
|
||||
if r * c < 256 * 256:
|
||||
raise RuntimeError("image too small, should be larger than 256x256")
|
||||
|
||||
bits = []
|
||||
assert method == "dwtDct"
|
||||
embed = maxDct.EmbedMaxDct(watermarks=[], wmLen=self._wmLen, **configs)
|
||||
bits = embed.decode(cv2Image)
|
||||
return self.reconstruct(bits)
|
||||
|
||||
except:
|
||||
raise e
|
||||
|
||||
|
||||
# A fixed 48-bit message that was choosen at random
|
||||
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
||||
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
||||
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
||||
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
||||
MATCH_VALUES = [
|
||||
[27, "No watermark detected"],
|
||||
[33, "Partial watermark match. Cannot determine with certainty."],
|
||||
[
|
||||
35,
|
||||
(
|
||||
"Likely watermarked. In our test 0.02% of real images were "
|
||||
'falsely detected as "Likely watermarked"'
|
||||
),
|
||||
],
|
||||
[
|
||||
49,
|
||||
(
|
||||
"Very likely watermarked. In our test no real images were "
|
||||
'falsely detected as "Very likely watermarked"'
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
class GetWatermarkMatch:
|
||||
def __init__(self, watermark):
|
||||
self.watermark = watermark
|
||||
self.num_bits = len(self.watermark)
|
||||
self.decoder = WatermarkDecoder("bits", self.num_bits)
|
||||
|
||||
def __call__(self, x: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Detects the number of matching bits the predefined watermark with one
|
||||
or multiple images. Images should be in cv2 format, e.g. h x w x c.
|
||||
|
||||
Args:
|
||||
x: ([B], h w, c) in range [0, 255]
|
||||
|
||||
Returns:
|
||||
number of matched bits ([B],)
|
||||
"""
|
||||
squeeze = len(x.shape) == 3
|
||||
if squeeze:
|
||||
x = x[None, ...]
|
||||
x = np.flip(x, axis=-1)
|
||||
|
||||
bs = x.shape[0]
|
||||
detected = np.empty((bs, self.num_bits), dtype=bool)
|
||||
for k in range(bs):
|
||||
detected[k] = self.decoder.decode(x[k], "dwtDct")
|
||||
result = np.sum(detected == self.watermark, axis=-1)
|
||||
if squeeze:
|
||||
return result[0]
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
get_watermark_match = GetWatermarkMatch(WATERMARK_BITS)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"filename",
|
||||
nargs="+",
|
||||
type=str,
|
||||
help="Image files to check for watermarks",
|
||||
)
|
||||
opts = parser.parse_args()
|
||||
|
||||
print(
|
||||
"""
|
||||
This script tries to detect watermarked images. Please be aware of
|
||||
the following:
|
||||
- As the watermark is supposed to be invisible, there is the risk that
|
||||
watermarked images may not be detected.
|
||||
- To maximize the chance of detection make sure that the image has the same
|
||||
dimensions as when the watermark was applied (most likely 1024x1024
|
||||
or 512x512).
|
||||
- Specific image manipulation may drastically decrease the chance that
|
||||
watermarks can be detected.
|
||||
- There is also the chance that an image has the characteristics of the
|
||||
watermark by chance.
|
||||
- The watermark script is public, anybody may watermark any images, and
|
||||
could therefore claim it to be generated.
|
||||
- All numbers below are based on a test using 10,000 images without any
|
||||
modifications after applying the watermark.
|
||||
"""
|
||||
)
|
||||
|
||||
for fn in opts.filename:
|
||||
image = cv2.imread(fn)
|
||||
if image is None:
|
||||
print(f"Couldn't read {fn}. Skipping")
|
||||
continue
|
||||
|
||||
num_bits = get_watermark_match(image)
|
||||
k = 0
|
||||
while num_bits > MATCH_VALUES[k][0]:
|
||||
k += 1
|
||||
print(
|
||||
f"{fn}: {MATCH_VALUES[k][1]}",
|
||||
f"Bits that matched the watermark {num_bits} from {len(WATERMARK_BITS)}\n",
|
||||
sep="\n\t",
|
||||
)
|
||||
328
scripts/demo/sampling.py
Normal file
328
scripts/demo/sampling.py
Normal file
@@ -0,0 +1,328 @@
|
||||
from pytorch_lightning import seed_everything
|
||||
from scripts.demo.streamlit_helpers import *
|
||||
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
|
||||
|
||||
SAVE_PATH = "outputs/demo/txt2img/"
|
||||
|
||||
SD_XL_BASE_RATIOS = {
|
||||
"0.5": (704, 1408),
|
||||
"0.52": (704, 1344),
|
||||
"0.57": (768, 1344),
|
||||
"0.6": (768, 1280),
|
||||
"0.68": (832, 1216),
|
||||
"0.72": (832, 1152),
|
||||
"0.78": (896, 1152),
|
||||
"0.82": (896, 1088),
|
||||
"0.88": (960, 1088),
|
||||
"0.94": (960, 1024),
|
||||
"1.0": (1024, 1024),
|
||||
"1.07": (1024, 960),
|
||||
"1.13": (1088, 960),
|
||||
"1.21": (1088, 896),
|
||||
"1.29": (1152, 896),
|
||||
"1.38": (1152, 832),
|
||||
"1.46": (1216, 832),
|
||||
"1.67": (1280, 768),
|
||||
"1.75": (1344, 768),
|
||||
"1.91": (1344, 704),
|
||||
"2.0": (1408, 704),
|
||||
"2.09": (1472, 704),
|
||||
"2.4": (1536, 640),
|
||||
"2.5": (1600, 640),
|
||||
"2.89": (1664, 576),
|
||||
"3.0": (1728, 576),
|
||||
}
|
||||
|
||||
VERSION2SPECS = {
|
||||
"SD-XL base": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": False,
|
||||
"config": "configs/inference/sd_xl_base.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_base_0.9.safetensors",
|
||||
"is_guided": True,
|
||||
},
|
||||
"sd-2.1": {
|
||||
"H": 512,
|
||||
"W": 512,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_2_1.yaml",
|
||||
"ckpt": "checkpoints/v2-1_512-ema-pruned.safetensors",
|
||||
"is_guided": True,
|
||||
},
|
||||
"sd-2.1-768": {
|
||||
"H": 768,
|
||||
"W": 768,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_2_1_768.yaml",
|
||||
"ckpt": "checkpoints/v2-1_768-ema-pruned.safetensors",
|
||||
},
|
||||
"SDXL-Refiner": {
|
||||
"H": 1024,
|
||||
"W": 1024,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"is_legacy": True,
|
||||
"config": "configs/inference/sd_xl_refiner.yaml",
|
||||
"ckpt": "checkpoints/sd_xl_refiner_0.9.safetensors",
|
||||
"is_guided": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_img(display=True, key=None, device="cuda"):
|
||||
image = get_interactive_image(key=key)
|
||||
if image is None:
|
||||
return None
|
||||
if display:
|
||||
st.image(image)
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
width, height = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
image = image.resize((width, height))
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
|
||||
return image.to(device)
|
||||
|
||||
|
||||
def run_txt2img(
|
||||
state, version, version_dict, is_legacy=False, return_latents=False, filter=None
|
||||
):
|
||||
if version == "SD-XL base":
|
||||
ratio = st.sidebar.selectbox("Ratio:", list(SD_XL_BASE_RATIOS.keys()), 10)
|
||||
W, H = SD_XL_BASE_RATIOS[ratio]
|
||||
else:
|
||||
H = st.sidebar.number_input(
|
||||
"H", value=version_dict["H"], min_value=64, max_value=2048
|
||||
)
|
||||
W = st.sidebar.number_input(
|
||||
"W", value=version_dict["W"], min_value=64, max_value=2048
|
||||
)
|
||||
C = version_dict["C"]
|
||||
F = version_dict["f"]
|
||||
|
||||
init_dict = {
|
||||
"orig_width": W,
|
||||
"orig_height": H,
|
||||
"target_width": W,
|
||||
"target_height": H,
|
||||
}
|
||||
value_dict = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
||||
init_dict,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
num_rows, num_cols, sampler = init_sampling(
|
||||
use_identity_guider=not version_dict["is_guided"]
|
||||
)
|
||||
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
if st.button("Sample"):
|
||||
st.write(f"**Model I:** {version}")
|
||||
out = do_sample(
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
H,
|
||||
W,
|
||||
C,
|
||||
F,
|
||||
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def run_img2img(
|
||||
state, version_dict, is_legacy=False, return_latents=False, filter=None
|
||||
):
|
||||
img = load_img()
|
||||
if img is None:
|
||||
return None
|
||||
H, W = img.shape[2], img.shape[3]
|
||||
|
||||
init_dict = {
|
||||
"orig_width": W,
|
||||
"orig_height": H,
|
||||
"target_width": W,
|
||||
"target_height": H,
|
||||
}
|
||||
value_dict = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(state["model"].conditioner),
|
||||
init_dict,
|
||||
)
|
||||
strength = st.number_input(
|
||||
"**Img2Img Strength**", value=0.5, min_value=0.0, max_value=1.0
|
||||
)
|
||||
num_rows, num_cols, sampler = init_sampling(
|
||||
img2img_strength=strength,
|
||||
use_identity_guider=not version_dict["is_guided"],
|
||||
)
|
||||
num_samples = num_rows * num_cols
|
||||
|
||||
if st.button("Sample"):
|
||||
out = do_img2img(
|
||||
repeat(img, "1 ... -> n ...", n=num_samples),
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
force_uc_zero_embeddings=["txt"] if not is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def apply_refiner(
|
||||
input,
|
||||
state,
|
||||
sampler,
|
||||
num_samples,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
filter=None,
|
||||
):
|
||||
init_dict = {
|
||||
"orig_width": input.shape[3] * 8,
|
||||
"orig_height": input.shape[2] * 8,
|
||||
"target_width": input.shape[3] * 8,
|
||||
"target_height": input.shape[2] * 8,
|
||||
}
|
||||
|
||||
value_dict = init_dict
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
|
||||
value_dict["crop_coords_top"] = 0
|
||||
value_dict["crop_coords_left"] = 0
|
||||
|
||||
value_dict["aesthetic_score"] = 6.0
|
||||
value_dict["negative_aesthetic_score"] = 2.5
|
||||
|
||||
st.warning(f"refiner input shape: {input.shape}")
|
||||
samples = do_img2img(
|
||||
input,
|
||||
state["model"],
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
skip_encode=True,
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
return samples
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
st.title("Stable Diffusion")
|
||||
version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
|
||||
version_dict = VERSION2SPECS[version]
|
||||
mode = st.radio("Mode", ("txt2img", "img2img"), 0)
|
||||
st.write("__________________________")
|
||||
|
||||
if version == "SD-XL base":
|
||||
add_pipeline = st.checkbox("Load SDXL-Refiner?", False)
|
||||
st.write("__________________________")
|
||||
else:
|
||||
add_pipeline = False
|
||||
|
||||
filter = DeepFloydDataFiltering(verbose=False)
|
||||
|
||||
seed = st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||
seed_everything(seed)
|
||||
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
||||
|
||||
state = init_st(version_dict)
|
||||
if state["msg"]:
|
||||
st.info(state["msg"])
|
||||
model = state["model"]
|
||||
|
||||
is_legacy = version_dict["is_legacy"]
|
||||
|
||||
prompt = st.text_input(
|
||||
"prompt",
|
||||
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
||||
)
|
||||
if is_legacy:
|
||||
negative_prompt = st.text_input("negative prompt", "")
|
||||
else:
|
||||
negative_prompt = "" # which is unused
|
||||
|
||||
if add_pipeline:
|
||||
st.write("__________________________")
|
||||
|
||||
version2 = "SDXL-Refiner"
|
||||
st.warning(
|
||||
f"Running with {version2} as the second stage model. Make sure to provide (V)RAM :) "
|
||||
)
|
||||
st.write("**Refiner Options:**")
|
||||
|
||||
version_dict2 = VERSION2SPECS[version2]
|
||||
state2 = init_st(version_dict2)
|
||||
st.info(state2["msg"])
|
||||
|
||||
stage2strength = st.number_input(
|
||||
"**Refinement strength**", value=0.3, min_value=0.0, max_value=1.0
|
||||
)
|
||||
|
||||
sampler2 = init_sampling(
|
||||
key=2,
|
||||
img2img_strength=stage2strength,
|
||||
use_identity_guider=not version_dict["is_guided"],
|
||||
get_num_samples=False,
|
||||
)
|
||||
st.write("__________________________")
|
||||
|
||||
if mode == "txt2img":
|
||||
out = run_txt2img(
|
||||
state,
|
||||
version,
|
||||
version_dict,
|
||||
is_legacy=is_legacy,
|
||||
return_latents=add_pipeline,
|
||||
filter=filter,
|
||||
)
|
||||
elif mode == "img2img":
|
||||
out = run_img2img(
|
||||
state,
|
||||
version_dict,
|
||||
is_legacy=is_legacy,
|
||||
return_latents=add_pipeline,
|
||||
filter=filter,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown mode {mode}")
|
||||
if isinstance(out, (tuple, list)):
|
||||
samples, samples_z = out
|
||||
else:
|
||||
samples = out
|
||||
|
||||
if add_pipeline:
|
||||
st.write("**Running Refinement Stage**")
|
||||
samples = apply_refiner(
|
||||
samples_z,
|
||||
state2,
|
||||
sampler2,
|
||||
samples_z.shape[0],
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt if is_legacy else "",
|
||||
filter=filter,
|
||||
)
|
||||
|
||||
if save_locally and samples is not None:
|
||||
perform_save_locally(save_path, samples)
|
||||
668
scripts/demo/streamlit_helpers.py
Normal file
668
scripts/demo/streamlit_helpers.py
Normal file
@@ -0,0 +1,668 @@
|
||||
import os
|
||||
from typing import Union, List
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import streamlit as st
|
||||
import torch
|
||||
from PIL import Image
|
||||
from einops import rearrange, repeat
|
||||
from imwatermark import WatermarkEncoder
|
||||
from omegaconf import OmegaConf, ListConfig
|
||||
from torch import autocast
|
||||
from torchvision import transforms
|
||||
from torchvision.utils import make_grid
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
from sgm.modules.diffusionmodules.sampling import (
|
||||
EulerEDMSampler,
|
||||
HeunEDMSampler,
|
||||
EulerAncestralSampler,
|
||||
DPMPP2SAncestralSampler,
|
||||
DPMPP2MSampler,
|
||||
LinearMultistepSampler,
|
||||
)
|
||||
from sgm.util import append_dims
|
||||
from sgm.util import instantiate_from_config
|
||||
|
||||
|
||||
class WatermarkEmbedder:
|
||||
def __init__(self, watermark):
|
||||
self.watermark = watermark
|
||||
self.num_bits = len(WATERMARK_BITS)
|
||||
self.encoder = WatermarkEncoder()
|
||||
self.encoder.set_watermark("bits", self.watermark)
|
||||
|
||||
def __call__(self, image: torch.Tensor):
|
||||
"""
|
||||
Adds a predefined watermark to the input image
|
||||
|
||||
Args:
|
||||
image: ([N,] B, C, H, W) in range [0, 1]
|
||||
|
||||
Returns:
|
||||
same as input but watermarked
|
||||
"""
|
||||
# watermarking libary expects input as cv2 format
|
||||
squeeze = len(image.shape) == 4
|
||||
if squeeze:
|
||||
image = image[None, ...]
|
||||
n = image.shape[0]
|
||||
image_np = rearrange(
|
||||
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||
).numpy()
|
||||
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||
for k in range(image_np.shape[0]):
|
||||
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||
image = torch.from_numpy(
|
||||
rearrange(image_np, "(n b) h w c -> n b c h w", n=n)
|
||||
).to(image.device)
|
||||
image = torch.clamp(image / 255, min=0.0, max=1.0)
|
||||
if squeeze:
|
||||
image = image[0]
|
||||
return image
|
||||
|
||||
|
||||
# A fixed 48-bit message that was choosen at random
|
||||
# WATERMARK_MESSAGE = 0xB3EC907BB19E
|
||||
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
|
||||
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
|
||||
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
|
||||
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
|
||||
|
||||
|
||||
@st.cache_resource()
|
||||
def init_st(version_dict, load_ckpt=True):
|
||||
state = dict()
|
||||
if not "model" in state:
|
||||
config = version_dict["config"]
|
||||
ckpt = version_dict["ckpt"]
|
||||
|
||||
config = OmegaConf.load(config)
|
||||
model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
|
||||
|
||||
state["msg"] = msg
|
||||
state["model"] = model
|
||||
state["ckpt"] = ckpt if load_ckpt else None
|
||||
state["config"] = config
|
||||
return state
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt=None, verbose=True):
|
||||
model = instantiate_from_config(config.model)
|
||||
|
||||
if ckpt is not None:
|
||||
print(f"Loading model from {ckpt}")
|
||||
if ckpt.endswith("ckpt"):
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
global_step = pl_sd["global_step"]
|
||||
st.info(f"loaded ckpt from global step {global_step}")
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
elif ckpt.endswith("safetensors"):
|
||||
sd = load_safetensors(ckpt)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
msg = None
|
||||
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
else:
|
||||
msg = None
|
||||
|
||||
model.cuda()
|
||||
model.eval()
|
||||
return model, msg
|
||||
|
||||
|
||||
def get_unique_embedder_keys_from_conditioner(conditioner):
|
||||
return list(set([x.input_key for x in conditioner.embedders]))
|
||||
|
||||
|
||||
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
|
||||
# Hardcoded demo settings; might undergo some changes in the future
|
||||
|
||||
value_dict = {}
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
if prompt is None:
|
||||
prompt = st.text_input(
|
||||
"Prompt", "A professional photograph of an astronaut riding a pig"
|
||||
)
|
||||
if negative_prompt is None:
|
||||
negative_prompt = st.text_input("Negative prompt", "")
|
||||
|
||||
value_dict["prompt"] = prompt
|
||||
value_dict["negative_prompt"] = negative_prompt
|
||||
|
||||
if key == "original_size_as_tuple":
|
||||
orig_width = st.number_input(
|
||||
"orig_width",
|
||||
value=init_dict["orig_width"],
|
||||
min_value=16,
|
||||
)
|
||||
orig_height = st.number_input(
|
||||
"orig_height",
|
||||
value=init_dict["orig_height"],
|
||||
min_value=16,
|
||||
)
|
||||
|
||||
value_dict["orig_width"] = orig_width
|
||||
value_dict["orig_height"] = orig_height
|
||||
|
||||
if key == "crop_coords_top_left":
|
||||
crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
|
||||
crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)
|
||||
|
||||
value_dict["crop_coords_top"] = crop_coord_top
|
||||
value_dict["crop_coords_left"] = crop_coord_left
|
||||
|
||||
if key == "aesthetic_score":
|
||||
value_dict["aesthetic_score"] = 6.0
|
||||
value_dict["negative_aesthetic_score"] = 2.5
|
||||
|
||||
if key == "target_size_as_tuple":
|
||||
target_width = st.number_input(
|
||||
"target_width",
|
||||
value=init_dict["target_width"],
|
||||
min_value=16,
|
||||
)
|
||||
target_height = st.number_input(
|
||||
"target_height",
|
||||
value=init_dict["target_height"],
|
||||
min_value=16,
|
||||
)
|
||||
|
||||
value_dict["target_width"] = target_width
|
||||
value_dict["target_height"] = target_height
|
||||
|
||||
return value_dict
|
||||
|
||||
|
||||
def perform_save_locally(save_path, samples):
|
||||
os.makedirs(os.path.join(save_path), exist_ok=True)
|
||||
base_count = len(os.listdir(os.path.join(save_path)))
|
||||
samples = embed_watemark(samples)
|
||||
for sample in samples:
|
||||
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
|
||||
Image.fromarray(sample.astype(np.uint8)).save(
|
||||
os.path.join(save_path, f"{base_count:09}.png")
|
||||
)
|
||||
base_count += 1
|
||||
|
||||
|
||||
def init_save_locally(_dir, init_value: bool = False):
|
||||
save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
|
||||
if save_locally:
|
||||
save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
|
||||
else:
|
||||
save_path = None
|
||||
|
||||
return save_locally, save_path
|
||||
|
||||
|
||||
class Img2ImgDiscretizationWrapper:
|
||||
"""
|
||||
wraps a discretizer, and prunes the sigmas
|
||||
params:
|
||||
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
|
||||
"""
|
||||
|
||||
def __init__(self, discretization, strength: float = 1.0):
|
||||
self.discretization = discretization
|
||||
self.strength = strength
|
||||
assert 0.0 <= self.strength <= 1.0
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# sigmas start large first, and decrease then
|
||||
sigmas = self.discretization(*args, **kwargs)
|
||||
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
|
||||
print("prune index:", max(int(self.strength * len(sigmas)), 1))
|
||||
sigmas = torch.flip(sigmas, (0,))
|
||||
print(f"sigmas after pruning: ", sigmas)
|
||||
return sigmas
|
||||
|
||||
|
||||
def get_guider(key):
|
||||
guider = st.sidebar.selectbox(
|
||||
f"Discretization #{key}",
|
||||
[
|
||||
"VanillaCFG",
|
||||
"IdentityGuider",
|
||||
],
|
||||
)
|
||||
|
||||
if guider == "IdentityGuider":
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
}
|
||||
elif guider == "VanillaCFG":
|
||||
scale = st.number_input(
|
||||
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
|
||||
)
|
||||
|
||||
thresholder = st.sidebar.selectbox(
|
||||
f"Thresholder #{key}",
|
||||
[
|
||||
"None",
|
||||
],
|
||||
)
|
||||
|
||||
if thresholder == "None":
|
||||
dyn_thresh_config = {
|
||||
"target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
|
||||
"params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return guider_config
|
||||
|
||||
|
||||
def init_sampling(
|
||||
key=1, img2img_strength=1.0, use_identity_guider=False, get_num_samples=True
|
||||
):
|
||||
if get_num_samples:
|
||||
num_rows = 1
|
||||
num_cols = st.number_input(
|
||||
f"num cols #{key}", value=2, min_value=1, max_value=10
|
||||
)
|
||||
|
||||
steps = st.sidebar.number_input(
|
||||
f"steps #{key}", value=50, min_value=1, max_value=1000
|
||||
)
|
||||
sampler = st.sidebar.selectbox(
|
||||
f"Sampler #{key}",
|
||||
[
|
||||
"EulerEDMSampler",
|
||||
"HeunEDMSampler",
|
||||
"EulerAncestralSampler",
|
||||
"DPMPP2SAncestralSampler",
|
||||
"DPMPP2MSampler",
|
||||
"LinearMultistepSampler",
|
||||
],
|
||||
0,
|
||||
)
|
||||
discretization = st.sidebar.selectbox(
|
||||
f"Discretization #{key}",
|
||||
[
|
||||
"LegacyDDPMDiscretization",
|
||||
"EDMDiscretization",
|
||||
],
|
||||
)
|
||||
|
||||
discretization_config = get_discretization(discretization, key=key)
|
||||
|
||||
guider_config = get_guider(key=key)
|
||||
|
||||
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
|
||||
if img2img_strength < 1.0:
|
||||
st.warning(
|
||||
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
|
||||
)
|
||||
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||
sampler.discretization, strength=img2img_strength
|
||||
)
|
||||
if get_num_samples:
|
||||
return num_rows, num_cols, sampler
|
||||
return sampler
|
||||
|
||||
|
||||
def get_discretization(discretization, key=1):
|
||||
if discretization == "LegacyDDPMDiscretization":
|
||||
use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False)
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
|
||||
"params": {"legacy_range": not use_new_range},
|
||||
}
|
||||
elif discretization == "EDMDiscretization":
|
||||
sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292
|
||||
sigma_max = st.number_input(f"sigma_max #{key}", value=14.61) # 14.6146
|
||||
rho = st.number_input(f"rho #{key}", value=3.0)
|
||||
discretization_config = {
|
||||
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
|
||||
"params": {
|
||||
"sigma_min": sigma_min,
|
||||
"sigma_max": sigma_max,
|
||||
"rho": rho,
|
||||
},
|
||||
}
|
||||
|
||||
return discretization_config
|
||||
|
||||
|
||||
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
|
||||
if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
|
||||
s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
|
||||
s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
|
||||
s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
|
||||
s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
|
||||
|
||||
if sampler_name == "EulerEDMSampler":
|
||||
sampler = EulerEDMSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
s_churn=s_churn,
|
||||
s_tmin=s_tmin,
|
||||
s_tmax=s_tmax,
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif sampler_name == "HeunEDMSampler":
|
||||
sampler = HeunEDMSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
s_churn=s_churn,
|
||||
s_tmin=s_tmin,
|
||||
s_tmax=s_tmax,
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif (
|
||||
sampler_name == "EulerAncestralSampler"
|
||||
or sampler_name == "DPMPP2SAncestralSampler"
|
||||
):
|
||||
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
|
||||
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
|
||||
|
||||
if sampler_name == "EulerAncestralSampler":
|
||||
sampler = EulerAncestralSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
eta=eta,
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif sampler_name == "DPMPP2SAncestralSampler":
|
||||
sampler = DPMPP2SAncestralSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
eta=eta,
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif sampler_name == "DPMPP2MSampler":
|
||||
sampler = DPMPP2MSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
verbose=True,
|
||||
)
|
||||
elif sampler_name == "LinearMultistepSampler":
|
||||
order = st.sidebar.number_input("order", value=4, min_value=1)
|
||||
sampler = LinearMultistepSampler(
|
||||
num_steps=steps,
|
||||
discretization_config=discretization_config,
|
||||
guider_config=guider_config,
|
||||
order=order,
|
||||
verbose=True,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown sampler {sampler_name}!")
|
||||
|
||||
return sampler
|
||||
|
||||
|
||||
def get_interactive_image(key=None) -> Image.Image:
|
||||
image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
|
||||
if image is not None:
|
||||
image = Image.open(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
return image
|
||||
|
||||
|
||||
def load_img(display=True, key=None):
|
||||
image = get_interactive_image(key=key)
|
||||
if image is None:
|
||||
return None
|
||||
if display:
|
||||
st.image(image)
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
|
||||
transform = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Lambda(lambda x: x * 2.0 - 1.0),
|
||||
]
|
||||
)
|
||||
img = transform(image)[None, ...]
|
||||
st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
|
||||
return img
|
||||
|
||||
|
||||
def get_init_img(batch_size=1, key=None):
|
||||
init_image = load_img(key=key).cuda()
|
||||
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
|
||||
return init_image
|
||||
|
||||
|
||||
def do_sample(
|
||||
model,
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
H,
|
||||
W,
|
||||
C,
|
||||
F,
|
||||
force_uc_zero_embeddings: List = None,
|
||||
batch2model_input: List = None,
|
||||
return_latents=False,
|
||||
filter=None,
|
||||
):
|
||||
if force_uc_zero_embeddings is None:
|
||||
force_uc_zero_embeddings = []
|
||||
if batch2model_input is None:
|
||||
batch2model_input = []
|
||||
|
||||
st.text("Sampling")
|
||||
|
||||
outputs = st.empty()
|
||||
precision_scope = autocast
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with model.ema_scope():
|
||||
num_samples = [num_samples]
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
num_samples,
|
||||
)
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor):
|
||||
print(key, batch[key].shape)
|
||||
elif isinstance(batch[key], list):
|
||||
print(key, [len(l) for l in batch[key]])
|
||||
else:
|
||||
print(key, batch[key])
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
)
|
||||
|
||||
for k in c:
|
||||
if not k == "crossattn":
|
||||
c[k], uc[k] = map(
|
||||
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
|
||||
)
|
||||
|
||||
additional_model_inputs = {}
|
||||
for k in batch2model_input:
|
||||
additional_model_inputs[k] = batch[k]
|
||||
|
||||
shape = (math.prod(num_samples), C, H // F, W // F)
|
||||
randn = torch.randn(shape).to("cuda")
|
||||
|
||||
def denoiser(input, sigma, c):
|
||||
return model.denoiser(
|
||||
model.model, input, sigma, c, **additional_model_inputs
|
||||
)
|
||||
|
||||
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
||||
samples_x = model.decode_first_stage(samples_z)
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if filter is not None:
|
||||
samples = filter(samples)
|
||||
|
||||
grid = torch.stack([samples])
|
||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||
outputs.image(grid.cpu().numpy())
|
||||
|
||||
if return_latents:
|
||||
return samples, samples_z
|
||||
return samples
|
||||
|
||||
|
||||
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
# Hardcoded demo setups; might undergo some changes in the future
|
||||
|
||||
batch = {}
|
||||
batch_uc = {}
|
||||
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
batch["txt"] = (
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
batch_uc["txt"] = (
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
elif key == "original_size_as_tuple":
|
||||
batch["original_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
torch.tensor(
|
||||
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||
)
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
elif key == "aesthetic_score":
|
||||
batch["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
batch_uc["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
|
||||
elif key == "target_size_as_tuple":
|
||||
batch["target_size_as_tuple"] = (
|
||||
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
else:
|
||||
batch[key] = value_dict[key]
|
||||
|
||||
for key in batch.keys():
|
||||
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
|
||||
batch_uc[key] = torch.clone(batch[key])
|
||||
return batch, batch_uc
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def do_img2img(
|
||||
img,
|
||||
model,
|
||||
sampler,
|
||||
value_dict,
|
||||
num_samples,
|
||||
force_uc_zero_embeddings=[],
|
||||
additional_kwargs={},
|
||||
offset_noise_level: int = 0.0,
|
||||
return_latents=False,
|
||||
skip_encode=False,
|
||||
filter=None,
|
||||
):
|
||||
st.text("Sampling")
|
||||
|
||||
outputs = st.empty()
|
||||
precision_scope = autocast
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with model.ema_scope():
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
[num_samples],
|
||||
)
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
)
|
||||
|
||||
for k in c:
|
||||
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
|
||||
|
||||
for k in additional_kwargs:
|
||||
c[k] = uc[k] = additional_kwargs[k]
|
||||
if skip_encode:
|
||||
z = img
|
||||
else:
|
||||
z = model.encode_first_stage(img)
|
||||
noise = torch.randn_like(z)
|
||||
sigmas = sampler.discretization(sampler.num_steps)
|
||||
sigma = sigmas[0]
|
||||
|
||||
st.info(f"all sigmas: {sigmas}")
|
||||
st.info(f"noising sigma: {sigma}")
|
||||
|
||||
if offset_noise_level > 0.0:
|
||||
noise = noise + offset_noise_level * append_dims(
|
||||
torch.randn(z.shape[0], device=z.device), z.ndim
|
||||
)
|
||||
noised_z = z + noise * append_dims(sigma, z.ndim)
|
||||
noised_z = noised_z / torch.sqrt(
|
||||
1.0 + sigmas[0] ** 2.0
|
||||
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
|
||||
|
||||
def denoiser(x, sigma, c):
|
||||
return model.denoiser(model.model, x, sigma, c)
|
||||
|
||||
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
|
||||
samples_x = model.decode_first_stage(samples_z)
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
if filter is not None:
|
||||
samples = filter(samples)
|
||||
|
||||
grid = embed_watemark(torch.stack([samples]))
|
||||
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
|
||||
outputs.image(grid.cpu().numpy())
|
||||
if return_latents:
|
||||
return samples, samples_z
|
||||
return samples
|
||||
Reference in New Issue
Block a user