mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 14:24:21 +01:00
929 lines
31 KiB
Python
929 lines
31 KiB
Python
import copy
|
|
import math
|
|
import os
|
|
from glob import glob
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
import streamlit as st
|
|
import torch
|
|
import torch.nn as nn
|
|
import torchvision.transforms as TT
|
|
from einops import rearrange, repeat
|
|
from imwatermark import WatermarkEncoder
|
|
from omegaconf import ListConfig, OmegaConf
|
|
from PIL import Image
|
|
from safetensors.torch import load_file as load_safetensors
|
|
from torch import autocast
|
|
from torchvision import transforms
|
|
from torchvision.utils import make_grid, save_image
|
|
|
|
from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
|
|
Txt2NoisyDiscretizationWrapper)
|
|
from scripts.util.detection.nsfw_and_watermark_dectection import \
|
|
DeepFloydDataFiltering
|
|
from sgm.inference.helpers import embed_watermark
|
|
from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
|
|
VanillaCFG)
|
|
from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
|
|
DPMPP2SAncestralSampler,
|
|
EulerAncestralSampler,
|
|
EulerEDMSampler,
|
|
HeunEDMSampler,
|
|
LinearMultistepSampler)
|
|
from sgm.util import append_dims, default, instantiate_from_config
|
|
|
|
|
|
@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
    """Create the demo state dict (model, config, checkpoint path, filter).

    Args:
        version_dict: Mapping read for the keys ``"config"`` (passed to
            ``OmegaConf.load``) and ``"ckpt"``.
        load_ckpt: When false, ``None`` is passed as the checkpoint to
            ``load_model_from_config`` and stored under ``"ckpt"``.
        load_filter: When true, a ``DeepFloydDataFiltering`` instance is
            stored under ``"filter"``.

    Returns:
        A dict with keys ``"msg"``, ``"model"``, ``"ckpt"``, ``"config"`` and,
        optionally, ``"filter"``.
    """
    state = {}
    # The dict was just created, so this branch always runs.
    if "model" not in state:
        config_path = version_dict["config"]
        ckpt_path = version_dict["ckpt"]
        used_ckpt = ckpt_path if load_ckpt else None

        config = OmegaConf.load(config_path)
        model, msg = load_model_from_config(config, used_ckpt)

        state.update(msg=msg, model=model, ckpt=used_ckpt, config=config)
        if load_filter:
            state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state
|
|
|
|
|
|
def load_model(model):
    """Call ``model.cuda()``; the return value of that call is discarded."""
    model.cuda()
    return None
|
|
|
|
|
|
# Module-wide switch read by initial_model_load() and unload_model().
lowvram_mode = False


def set_lowvram_mode(mode):
    """Store ``mode`` in the module-level ``lowvram_mode`` flag."""
    globals()["lowvram_mode"] = mode
|
|
|
|
|
|
def initial_model_load(model):
    """Prepare a freshly built model according to the ``lowvram_mode`` flag.

    With ``lowvram_mode`` set, only ``model.model.half()`` is called;
    otherwise ``model.cuda()`` is called. The same ``model`` object is
    returned in both cases.
    """
    if not lowvram_mode:
        model.cuda()
        return model
    model.model.half()
    return model
|
|
|
|
|
|
def unload_model(model):
    """In low-VRAM mode, call ``model.cpu()`` and empty the CUDA cache.

    Does nothing when the module-level ``lowvram_mode`` flag is falsy.
    """
    if not lowvram_mode:
        return
    model.cpu()
    torch.cuda.empty_cache()
|
|
|
|
|
|
def load_model_from_config(config, ckpt=None, verbose=True):
    """Instantiate ``config.model`` and optionally load weights from ``ckpt``.

    Args:
        config: Object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Checkpoint path or ``None``. Paths ending in ``"ckpt"`` are read
            with ``torch.load`` (weights taken from the ``"state_dict"``
            entry); paths ending in ``"safetensors"`` are read with
            ``load_safetensors``.
        verbose: When true, missing and unexpected state-dict keys are printed.

    Returns:
        ``(model, msg)``; ``msg`` is always ``None``. The model has been passed
        through ``initial_model_load`` and put in eval mode.

    Raises:
        NotImplementedError: If ``ckpt`` has neither supported suffix.
    """
    model = instantiate_from_config(config.model)
    msg = None

    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        if ckpt.endswith("ckpt"):
            # NOTE(review): torch.load unpickles the file; only use with
            # checkpoints from a trusted source.
            pl_sd = torch.load(ckpt, map_location="cpu")
            if "global_step" in pl_sd:
                global_step = pl_sd["global_step"]
                st.info(f"loaded ckpt from global step {global_step}")
                print(f"Global Step: {pl_sd['global_step']}")
            sd = pl_sd["state_dict"]
        elif ckpt.endswith("safetensors"):
            sd = load_safetensors(ckpt)
        else:
            raise NotImplementedError

        # strict=False: mismatching keys are reported below instead of raising.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if verbose:
            reports = (("missing keys:", missing), ("unexpected keys:", unexpected))
            for title, key_list in reports:
                if len(key_list) > 0:
                    print(title)
                    print(key_list)

    model = initial_model_load(model)
    model.eval()
    return model, msg
|
|
|
|
|
|
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    The result is a list built from a set, so its order is not meaningful.
    """
    unique_keys = set()
    for embedder in conditioner.embedders:
        unique_keys.add(embedder.input_key)
    return list(unique_keys)
|
|
|
|
|
|
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Gather conditioning values for ``keys``, mostly via Streamlit widgets.

    Args:
        keys: Iterable of embedder input keys. Keys not handled below are
            ignored.
        init_dict: Mapping read for ``"orig_width"``, ``"orig_height"``,
            ``"target_width"`` and ``"target_height"`` when the corresponding
            keys are requested.
        prompt: Initial text for the prompt field; a built-in default is used
            when ``None``.
        negative_prompt: Initial text for the negative prompt field; ``""``
            when ``None``.

    Returns:
        A dict of the collected values (e.g. ``"prompt"``, ``"orig_width"``,
        ``"fps"``, ``"fps_id"``, ``"pool_image"``).
    """
    # Hardcoded demo settings; might undergo some changes in the future

    value_dict = {}
    for key in keys:
        if key == "txt":
            if prompt is None:
                prompt = "A professional photograph of an astronaut riding a pig"
            if negative_prompt is None:
                negative_prompt = ""

            prompt = st.text_input("Prompt", prompt)
            negative_prompt = st.text_input("Negative prompt", negative_prompt)
            value_dict.update(prompt=prompt, negative_prompt=negative_prompt)

        elif key == "original_size_as_tuple":
            value_dict["orig_width"] = st.number_input(
                "orig_width", value=init_dict["orig_width"], min_value=16
            )
            value_dict["orig_height"] = st.number_input(
                "orig_height", value=init_dict["orig_height"], min_value=16
            )

        elif key == "crop_coords_top_left":
            value_dict["crop_coords_top"] = st.number_input(
                "crop_coords_top", value=0, min_value=0
            )
            value_dict["crop_coords_left"] = st.number_input(
                "crop_coords_left", value=0, min_value=0
            )

        elif key == "aesthetic_score":
            # Fixed scores for the conditional / unconditional branches.
            value_dict.update(aesthetic_score=6.0, negative_aesthetic_score=2.5)

        elif key == "target_size_as_tuple":
            for field in ("target_width", "target_height"):
                value_dict[field] = init_dict[field]

        elif key in ("fps_id", "fps"):
            fps = st.number_input("fps", value=6, min_value=1)
            value_dict["fps"] = fps
            value_dict["fps_id"] = fps - 1

        elif key == "motion_bucket_id":
            value_dict["motion_bucket_id"] = st.number_input(
                "motion bucket id", 0, 511, value=127
            )

        elif key == "pool_image":
            st.text("Image for pool conditioning")
            pooled = load_img(
                key="pool_image_input",
                size=224,
                center_crop=True,
            )
            if pooled is None:
                # No upload yet: use an all-zero placeholder image.
                st.info("Need an image here")
                pooled = torch.zeros(1, 3, 224, 224)
            value_dict["pool_image"] = pooled

    return value_dict
|
|
|
|
|
|
def perform_save_locally(save_path, samples):
    """Watermark ``samples`` and write each one as a PNG under ``save_path``.

    The directory is created if missing. File names are zero-padded
    nine-digit counters starting at the number of entries already in the
    directory. Each sample is scaled by 255 and cast to ``uint8`` before
    saving.
    """
    target_dir = os.path.join(save_path)
    os.makedirs(target_dir, exist_ok=True)
    first_index = len(os.listdir(target_dir))
    samples = embed_watermark(samples)
    for file_index, sample in enumerate(samples, start=first_index):
        pixels = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        png = Image.fromarray(pixels.astype(np.uint8))
        png.save(os.path.join(save_path, f"{file_index:09}.png"))
|
|
|
|
|
|
def init_save_locally(_dir, init_value: bool = False):
    """Ask whether to save images locally and, if so, where.

    Returns:
        ``(save_locally, save_path)``. ``save_path`` is ``None`` when the
        checkbox is off; otherwise it is the text-field value, which defaults
        to ``<_dir>/samples``.
    """
    save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
    save_path = (
        st.text_input("Save path", value=os.path.join(_dir, "samples"))
        if save_locally
        else None
    )
    return save_locally, save_path
|
|
|
|
|
|
def get_guider(options, key):
    """Build a guider config dict from sidebar selections.

    Args:
        options: Dict of defaults. Read for ``"guider"`` (selectbox index),
            ``"cfg"``, ``"min_cfg"`` and ``"num_frames"`` (required for
            ``LinearPredictionGuider``). NOTE: ``"additional_guider_kwargs"``
            is *popped*, so the caller's dict is mutated.
        key: Suffix appended to widget labels to keep them distinct.

    Returns:
        A dict with a ``"target"`` import path and, except for
        ``IdentityGuider``, a ``"params"`` dict.

    Raises:
        NotImplementedError: On an unrecognised guider or scale schedule.
    """
    guider = st.sidebar.selectbox(
        # Was "Discretization #{key}", a copy-paste of the discretization
        # selector's label in init_sampling.
        f"Guider #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
            "LinearPredictionGuider",
        ],
        options.get("guider", 0),
    )

    additional_guider_kwargs = options.pop("additional_guider_kwargs", {})

    if guider == "IdentityGuider":
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
        }
    elif guider == "VanillaCFG":
        scale_schedule = st.sidebar.selectbox(
            f"Scale schedule #{key}",
            ["Identity", "Oscillating"],
        )

        if scale_schedule == "Identity":
            scale = st.number_input(
                f"cfg-scale #{key}",
                value=options.get("cfg", 5.0),
                min_value=0.0,
            )

            scale_schedule_config = {
                "target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
                "params": {"scale": scale},
            }

        elif scale_schedule == "Oscillating":
            small_scale = st.number_input(
                f"small cfg-scale #{key}",
                value=4.0,
                min_value=0.0,
            )

            large_scale = st.number_input(
                f"large cfg-scale #{key}",
                value=16.0,
                min_value=0.0,
            )

            sigma_cutoff = st.number_input(
                f"sigma cutoff #{key}",
                value=1.0,
                min_value=0.0,
            )

            scale_schedule_config = {
                "target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
                "params": {
                    "small_scale": small_scale,
                    "large_scale": large_scale,
                    "sigma_cutoff": sigma_cutoff,
                },
            }
        else:
            raise NotImplementedError

        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {
                "scale_schedule_config": scale_schedule_config,
                **additional_guider_kwargs,
            },
        }
    elif guider == "LinearPredictionGuider":
        max_scale = st.number_input(
            f"max-cfg-scale #{key}",
            value=options.get("cfg", 1.5),
            min_value=1.0,
        )
        min_scale = st.number_input(
            # Label now carries the key suffix like every sibling widget.
            f"min guidance scale #{key}",
            value=options.get("min_cfg", 1.0),
            min_value=1.0,
            max_value=10.0,
        )

        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
            "params": {
                "max_scale": max_scale,
                "min_scale": min_scale,
                "num_frames": options["num_frames"],
                **additional_guider_kwargs,
            },
        }
    else:
        raise NotImplementedError
    return guider_config
|
|
|
|
|
|
def init_sampling(
    key=1,
    img2img_strength: Optional[float] = None,
    specify_num_samples: bool = True,
    stage2strength: Optional[float] = None,
    options: Optional[Dict[str, int]] = None,
):
    """Build a sampler from Streamlit widget selections.

    Args:
        key: Suffix appended to widget labels.
        img2img_strength: If given, the sampler's discretization is wrapped in
            ``Img2ImgDiscretizationWrapper`` with this strength.
        specify_num_samples: When true, a "num cols" number input is shown.
        stage2strength: If given, the (possibly already wrapped)
            discretization is wrapped in ``Txt2NoisyDiscretizationWrapper``.
        options: Optional defaults, read for ``"num_steps"``, ``"sampler"``
            and ``"discretization"`` and forwarded to ``get_discretization``
            and ``get_guider``.

    Returns:
        ``(sampler, num_rows, num_cols)``; ``num_rows`` is always 1.
    """
    if options is None:
        options = {}

    num_rows = 1
    num_cols = 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=num_cols, min_value=1, max_value=10
        )

    steps = st.sidebar.number_input(
        f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
    )

    sampler_choices = [
        "EulerEDMSampler",
        "HeunEDMSampler",
        "EulerAncestralSampler",
        "DPMPP2SAncestralSampler",
        "DPMPP2MSampler",
        "LinearMultistepSampler",
    ]
    sampler_name = st.sidebar.selectbox(
        f"Sampler #{key}", sampler_choices, options.get("sampler", 0)
    )

    discretization_choices = [
        "LegacyDDPMDiscretization",
        "EDMDiscretization",
    ]
    discretization_name = st.sidebar.selectbox(
        f"Discretization #{key}",
        discretization_choices,
        options.get("discretization", 0),
    )

    discretization_config = get_discretization(
        discretization_name, options=options, key=key
    )
    guider_config = get_guider(options=options, key=key)
    sampler = get_sampler(
        sampler_name, steps, discretization_config, guider_config, key=key
    )

    if img2img_strength is not None:
        st.warning(
            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
        )
        sampler.discretization = Img2ImgDiscretizationWrapper(
            sampler.discretization, strength=img2img_strength
        )
    if stage2strength is not None:
        # Applied after (and on top of) the img2img wrapper when both are set.
        sampler.discretization = Txt2NoisyDiscretizationWrapper(
            sampler.discretization, strength=stage2strength, original_steps=steps
        )
    return sampler, num_rows, num_cols
|
|
|
|
|
|
def get_discretization(discretization, options, key=1):
    """Return the config dict for the named discretization.

    Args:
        discretization: ``"LegacyDDPMDiscretization"`` or
            ``"EDMDiscretization"``.
        options: Dict of defaults; read for ``"sigma_min"``, ``"sigma_max"``
            and ``"rho"`` in the EDM case.
        key: Suffix appended to widget labels.

    Returns:
        A dict with a ``"target"`` import path and, for EDM, a ``"params"``
        dict filled from Streamlit number inputs.

    Raises:
        ValueError: If ``discretization`` is not one of the supported names.
            (Previously this fell through to an UnboundLocalError.)
    """
    if discretization == "LegacyDDPMDiscretization":
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }

    if discretization == "EDMDiscretization":
        sigma_min = st.number_input(
            f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
        )  # 0.0292
        sigma_max = st.number_input(
            f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
        )  # 14.6146
        rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
        return {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": sigma_min,
                "sigma_max": sigma_max,
                "rho": rho,
            },
        }

    raise ValueError(f"unknown discretization {discretization}!")
|
|
|
|
|
|
def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
    """Instantiate the named sampler, reading extra settings from the sidebar.

    Args:
        sampler_name: One of ``"EulerEDMSampler"``, ``"HeunEDMSampler"``,
            ``"EulerAncestralSampler"``, ``"DPMPP2SAncestralSampler"``,
            ``"DPMPP2MSampler"`` or ``"LinearMultistepSampler"``.
        steps: Passed to the sampler as ``num_steps``.
        discretization_config: Passed through to the sampler.
        guider_config: Passed through to the sampler.
        key: Suffix appended to every widget label so that several sampler
            panels can coexist. The ancestral and multistep widgets previously
            lacked this suffix, unlike the EDM ones.

    Returns:
        The constructed sampler instance.

    Raises:
        ValueError: If ``sampler_name`` is not recognised.
    """
    # Arguments shared by every sampler constructor below.
    common = dict(
        num_steps=steps,
        discretization_config=discretization_config,
        guider_config=guider_config,
        verbose=True,
    )

    if sampler_name in ("EulerEDMSampler", "HeunEDMSampler"):
        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)

        sampler_cls = (
            EulerEDMSampler if sampler_name == "EulerEDMSampler" else HeunEDMSampler
        )
        return sampler_cls(
            s_churn=s_churn,
            s_tmin=s_tmin,
            s_tmax=s_tmax,
            s_noise=s_noise,
            **common,
        )

    if sampler_name in ("EulerAncestralSampler", "DPMPP2SAncestralSampler"):
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)
        eta = st.sidebar.number_input(f"eta #{key}", value=1.0, min_value=0.0)

        sampler_cls = (
            EulerAncestralSampler
            if sampler_name == "EulerAncestralSampler"
            else DPMPP2SAncestralSampler
        )
        return sampler_cls(eta=eta, s_noise=s_noise, **common)

    if sampler_name == "DPMPP2MSampler":
        return DPMPP2MSampler(**common)

    if sampler_name == "LinearMultistepSampler":
        order = st.sidebar.number_input(f"order #{key}", value=4, min_value=1)
        return LinearMultistepSampler(order=order, **common)

    raise ValueError(f"unknown sampler {sampler_name}!")
|
|
|
|
|
|
def get_interactive_image(key=None) -> Image.Image:
    """Show a file uploader and return the uploaded image in RGB mode.

    NOTE(review): a function with the same name is defined again further down
    in this file and replaces this one at import time. ``key`` was added here
    (defaulting to ``None``) so both definitions accept the same arguments.

    Args:
        key: Optional Streamlit widget key forwarded to the uploader.

    Returns:
        The uploaded ``PIL.Image.Image`` converted to RGB if necessary, or
        ``None`` when nothing has been uploaded.
    """
    uploaded = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if uploaded is None:
        return None
    image = Image.open(uploaded)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
|
|
|
|
|
|
def load_img(
    display: bool = True,
    size: Union[None, int, Tuple[int, int]] = None,
    center_crop: bool = False,
    key=None,
):
    """Load an uploaded image as a ``[-1, 1]``-scaled tensor with a batch axis.

    Args:
        display: When true, the uploaded image is shown with ``st.image``.
        size: If not ``None``, passed to ``transforms.Resize``.
        center_crop: When true, ``transforms.CenterCrop(size)`` is applied
            after the optional resize. NOTE(review): this presumably requires
            ``size`` to be set; confirm against callers.
        key: Optional Streamlit widget key forwarded to the file uploader.
            Callers in this file already pass ``key=...``, which the previous
            signature rejected with a ``TypeError``.

    Returns:
        ``transform(image)[None, ...]`` with values mapped by ``2 * x - 1``,
        or ``None`` when no image has been uploaded.
    """
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")

    steps = []
    if size is not None:
        steps.append(transforms.Resize(size))
    if center_crop:
        steps.append(transforms.CenterCrop(size))
    steps.append(transforms.ToTensor())
    # Map ToTensor's output to the range produced by 2 * x - 1.
    steps.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))

    img = transforms.Compose(steps)(image)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img
|
|
|
|
|
|
def get_init_img(batch_size=1, key=None):
    """Load the uploaded image, move it to CUDA and tile it along the batch axis.

    The result of ``load_img(key=key)`` is used directly, so this fails if no
    image has been uploaded yet (``load_img`` returns ``None`` in that case).
    """
    single = load_img(key=key).cuda()
    return repeat(single, "1 ... -> b ...", b=batch_size)
|
|
|
|
|
|
def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    batch2model_input: List = None,
    return_latents=False,
    filter=None,
    T=None,
    additional_batch_uc_fields=None,
    decoding_t=None,
):
    """Sample from ``model`` with ``sampler`` and display the decoded result.

    Runs under ``torch.no_grad()``, CUDA autocast and ``model.ema_scope()``.
    The conditioner, the denoiser/network and the first-stage model are each
    passed to ``load_model`` before use and ``unload_model`` afterwards.

    Args:
        model: Object providing ``conditioner``, ``denoiser``, ``model``,
            ``first_stage_model``, ``ema_scope`` and ``decode_first_stage``.
        sampler: Callable invoked as ``sampler(denoiser, randn, cond=, uc=)``;
            its ``guider`` attribute is inspected for
            ``image_only_indicator``.
        value_dict: Conditioning values forwarded to ``get_batch``.
        num_samples: Number of samples; combined with ``T`` when ``T`` is set.
        H, W, C, F: The initial noise has shape
            ``(prod(num_samples), C, H // F, W // F)``.
        force_uc_zero_embeddings, force_cond_zero_embeddings: Forwarded to
            ``get_unconditional_conditioning``.
        batch2model_input: Batch keys passed to the network as extra keyword
            arguments.
        return_latents: When true, also return the sampler output.
        filter: Optional callable applied to the decoded samples.
        T: Optional frame count; enables the video code paths.
        additional_batch_uc_fields: Forwarded to ``get_batch``.
        decoding_t: Assigned to ``model.en_and_decode_n_samples_a_time``.

    Returns:
        Decoded samples clamped to ``[0, 1]``, or ``(samples, samples_z)``
        when ``return_latents`` is true.
    """
    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
    batch2model_input = default(batch2model_input, [])
    additional_batch_uc_fields = default(additional_batch_uc_fields, [])

    st.text("Sampling")

    outputs = st.empty()
    with torch.no_grad(), autocast("cuda"), model.ema_scope():
        num_samples = [num_samples] if T is None else [num_samples, T]
        total = math.prod(num_samples)

        load_model(model.conditioner)
        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner),
            value_dict,
            num_samples,
            T=T,
            additional_batch_uc_fields=additional_batch_uc_fields,
        )
        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=force_uc_zero_embeddings,
            force_cond_zero_embeddings=force_cond_zero_embeddings,
        )
        unload_model(model.conditioner)

        for k in c:
            if k != "crossattn":
                # Truncate to the batch size and move to the GPU.
                cond_k = c[k][:total].to("cuda")
                uncond_k = uc[k][:total].to("cuda")
                c[k], uc[k] = cond_k, uncond_k
            if T is not None and k in ("crossattn", "concat"):
                # Repeat along a new time axis and fold it into the batch axis.
                uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
                uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
                c[k] = repeat(c[k], "b ... -> b t ...", t=T)
                c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)

        additional_model_inputs = {}
        for k in batch2model_input:
            if k != "image_only_indicator":
                additional_model_inputs[k] = batch[k]
                continue

            assert T is not None
            if isinstance(sampler.guider, (VanillaCFG, LinearPredictionGuider)):
                # Twice as many rows as samples for these guiders.
                indicator = torch.zeros(num_samples[0] * 2, num_samples[1])
            else:
                indicator = torch.zeros(num_samples)
            additional_model_inputs[k] = indicator.to("cuda")

        shape = (total, C, H // F, W // F)
        # Noise is drawn on the CPU, then moved to the GPU.
        randn = torch.randn(shape).to("cuda")

        def denoiser(input, sigma, c):
            return model.denoiser(
                model.model, input, sigma, c, **additional_model_inputs
            )

        load_model(model.denoiser)
        load_model(model.model)
        samples_z = sampler(denoiser, randn, cond=c, uc=uc)
        unload_model(model.model)
        unload_model(model.denoiser)

        load_model(model.first_stage_model)
        # Decode n frames at a time
        model.en_and_decode_n_samples_a_time = decoding_t
        samples_x = model.decode_first_stage(samples_z)
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
        unload_model(model.first_stage_model)

        if filter is not None:
            samples = filter(samples)

        if T is None:
            grid = rearrange(torch.stack([samples]), "n b c h w -> (n h) (b w) c")
            outputs.image(grid.cpu().numpy())
        else:
            as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
            for i, vid in enumerate(as_vids):
                grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
                st.image(
                    grid.cpu().numpy(),
                    f"Sample #{i} as image",
                )

        if return_latents:
            return samples, samples_z
        return samples
|
|
|
|
|
|
def get_batch(
    keys,
    value_dict: dict,
    N: "Union[List, ListConfig]",
    device: str = "cuda",
    T: int = None,
    additional_batch_uc_fields: Optional[List[str]] = None,
):
    """Assemble conditional and unconditional batches for the given keys.

    Args:
        keys: Embedder input keys to populate.
        value_dict: Source values (e.g. ``"prompt"``, ``"orig_height"``,
            ``"fps"``, ``"cond_aug"``).
        N: Batch dimensions; most entries are repeated ``math.prod(N)`` times,
            while ``cond_frames`` and ``cond_frames_without_noise`` are
            repeated ``N[0]`` times.
        device: Device for tensors created here. The ``cond_aug`` entry now
            honours this too; it was previously hard-coded to ``"cuda"``.
        T: If given, stored in the batch as ``"num_video_frames"``.
        additional_batch_uc_fields: Names of non-tensor batch entries that are
            also shallow-copied into the unconditional batch. ``None`` means
            no extra fields.

    Returns:
        ``(batch, batch_uc)``. Every tensor in ``batch`` without an explicit
        unconditional counterpart is cloned into ``batch_uc``.
    """
    # Hardcoded demo setups; might undergo some changes in the future

    extra_uc_fields = (
        () if additional_batch_uc_fields is None else additional_batch_uc_fields
    )
    total = math.prod(N)

    batch = {}
    batch_uc = {}

    for key in keys:
        if key == "txt":
            batch["txt"] = [value_dict["prompt"]] * total
            batch_uc["txt"] = [value_dict["negative_prompt"]] * total

        elif key == "original_size_as_tuple":
            batch["original_size_as_tuple"] = (
                torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                .to(device)
                .repeat(total, 1)
            )
        elif key == "crop_coords_top_left":
            batch["crop_coords_top_left"] = (
                torch.tensor(
                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
                )
                .to(device)
                .repeat(total, 1)
            )
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = (
                torch.tensor([value_dict["aesthetic_score"]])
                .to(device)
                .repeat(total, 1)
            )
            batch_uc["aesthetic_score"] = (
                torch.tensor([value_dict["negative_aesthetic_score"]])
                .to(device)
                .repeat(total, 1)
            )

        elif key == "target_size_as_tuple":
            batch["target_size_as_tuple"] = (
                torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                .to(device)
                .repeat(total, 1)
            )
        elif key in ("fps", "fps_id", "motion_bucket_id", "cond_aug"):
            # One value per batch element; the value_dict entry shares the
            # key's name.
            batch[key] = torch.tensor([value_dict[key]]).to(device).repeat(total)
        elif key == "pool_image":
            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=total).to(
                device, dtype=torch.half
            )
        elif key == "cond_frames":
            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
        elif key == "cond_frames_without_noise":
            batch[key] = repeat(
                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
            )
        else:
            batch[key] = value_dict[key]

    if T is not None:
        batch["num_video_frames"] = T

    for key in batch:
        if key in batch_uc:
            continue
        if isinstance(batch[key], torch.Tensor):
            batch_uc[key] = torch.clone(batch[key])
        elif key in extra_uc_fields:
            batch_uc[key] = copy.copy(batch[key])
    return batch, batch_uc
|
|
|
|
|
|
@torch.no_grad()
def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings: Optional[List] = None,
    force_cond_zero_embeddings: Optional[List] = None,
    additional_kwargs=None,
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    add_noise=True,
):
    """Run image-to-image sampling starting from ``img`` and display the result.

    Runs under CUDA autocast and ``model.ema_scope()``; gradient tracking is
    disabled by the decorator. Sub-models are passed to ``load_model`` /
    ``unload_model`` around each use.

    Args:
        img: Input image, or a latent when ``skip_encode`` is true.
        model: Object providing ``conditioner``, ``denoiser``, ``model``,
            ``first_stage_model``, ``ema_scope``, ``encode_first_stage`` and
            ``decode_first_stage``.
        sampler: Callable sampler exposing ``discretization`` and
            ``num_steps``.
        value_dict: Conditioning values forwarded to ``get_batch``.
        num_samples: Batch size.
        force_uc_zero_embeddings, force_cond_zero_embeddings: Forwarded to
            ``get_unconditional_conditioning``.
        additional_kwargs: Optional mapping whose entries are written into
            both the conditional and unconditional dicts. ``None`` means
            empty.
        offset_noise_level: If positive, one extra random value per sample
            (scaled by this factor and broadcast over the remaining axes) is
            added to the noise.
        return_latents: When true, also return the sampler output.
        skip_encode: When true, ``img`` is used as the latent directly.
        filter: Optional callable applied to the decoded samples.
        add_noise: When false, the latent is only rescaled, not noised.

    Returns:
        Decoded samples clamped to ``[0, 1]``, or ``(samples, samples_z)``
        when ``return_latents`` is true.
    """
    extra_cond = {} if additional_kwargs is None else additional_kwargs

    st.text("Sampling")

    outputs = st.empty()
    with autocast("cuda"), model.ema_scope():
        load_model(model.conditioner)
        batch, batch_uc = get_batch(
            get_unique_embedder_keys_from_conditioner(model.conditioner),
            value_dict,
            [num_samples],
        )
        c, uc = model.conditioner.get_unconditional_conditioning(
            batch,
            batch_uc=batch_uc,
            force_uc_zero_embeddings=force_uc_zero_embeddings,
            force_cond_zero_embeddings=force_cond_zero_embeddings,
        )
        unload_model(model.conditioner)
        for k in c:
            cond_k = c[k][:num_samples].to("cuda")
            uncond_k = uc[k][:num_samples].to("cuda")
            c[k], uc[k] = cond_k, uncond_k

        for k, extra in extra_cond.items():
            c[k] = uc[k] = extra

        if skip_encode:
            z = img
        else:
            load_model(model.first_stage_model)
            z = model.encode_first_stage(img)
            unload_model(model.first_stage_model)

        # Drawn even when add_noise is false so random-number consumption
        # does not depend on that flag.
        noise = torch.randn_like(z)

        sigmas = sampler.discretization(sampler.num_steps).cuda()
        sigma = sigmas[0]

        st.info(f"all sigmas: {sigmas}")
        st.info(f"noising sigma: {sigma}")
        if offset_noise_level > 0.0:
            noise = noise + offset_noise_level * append_dims(
                torch.randn(z.shape[0], device=z.device), z.ndim
            )

        # Note: hardcoded to DDPM-like scaling. need to generalize later.
        scaling = torch.sqrt(1.0 + sigmas[0] ** 2.0)
        if add_noise:
            noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
            noised_z = noised_z / scaling
        else:
            noised_z = z / scaling

        def denoiser(x, sigma, c):
            return model.denoiser(model.model, x, sigma, c)

        load_model(model.denoiser)
        load_model(model.model)
        samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
        unload_model(model.model)
        unload_model(model.denoiser)

        load_model(model.first_stage_model)
        samples_x = model.decode_first_stage(samples_z)
        unload_model(model.first_stage_model)
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

        if filter is not None:
            samples = filter(samples)

        # `grid` used to be read here without ever being assigned; build it
        # from the samples the same way do_sample does.
        grid = torch.stack([samples])
        grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
        outputs.image(grid.cpu().numpy())
        if return_latents:
            return samples, samples_z
        return samples
|
|
|
|
|
|
def get_resizing_factor(
    desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
    """Return the scale factor used to resize ``current_shape`` before cropping.

    Both shapes are two-element sequences; the ratio ``shape[1] / shape[0]``
    of each is compared to choose which sides are matched.

    Returns:
        ``min(desired) / min(current)`` when the current ratio is at least as
        extreme as the desired one; ``max(desired) / min(current)`` when the
        current ratio lies on the opposite side of 1; otherwise
        ``max(desired) / max(current)``.
    """
    desired_ratio = desired_shape[1] / desired_shape[0]
    current_ratio = current_shape[1] / current_shape[0]

    if desired_ratio >= 1.0:
        more_extreme = current_ratio >= desired_ratio
        opposite_side = current_ratio < 1.0
    else:
        more_extreme = current_ratio <= desired_ratio
        opposite_side = current_ratio > 1

    if more_extreme:
        return min(desired_shape) / min(current_shape)
    if opposite_side:
        return max(desired_shape) / min(current_shape)
    return max(desired_shape) / max(current_shape)
|
|
|
|
|
|
def get_interactive_image(key=None) -> Image.Image:
    """Show a file uploader and return the uploaded image in RGB mode.

    Args:
        key: Optional Streamlit widget key forwarded to the uploader.

    Returns:
        The uploaded image as a ``PIL.Image.Image`` (converted to RGB when its
        mode differs), or ``None`` when nothing has been uploaded.
    """
    uploaded = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if uploaded is None:
        return uploaded
    picture = Image.open(uploaded)
    return picture if picture.mode == "RGB" else picture.convert("RGB")
|
|
|
|
|
|
def load_img_for_prediction(
    W: int, H: int, display=True, key=None, device="cuda"
) -> torch.Tensor:
    """Load an uploaded image, resize and crop it to ``H`` x ``W``.

    Args:
        W: Crop width.
        H: Crop height.
        display: When true, both the raw upload and the processed crop are
            shown with ``st.image``.
        key: Optional Streamlit widget key for the uploader.
        device: Device the returned tensor is moved to.

    Returns:
        The cropped image with a leading batch axis, rescaled by
        ``x * 2 - 1``, on ``device``; ``None`` if nothing has been uploaded.
    """
    uploaded = get_interactive_image(key=key)
    if uploaded is None:
        return None
    if display:
        st.image(uploaded)
    w, h = uploaded.size

    # Channels first, float32 in [0, 1], plus a batch axis.
    channels_first = np.array(uploaded).transpose(2, 0, 1)
    image = torch.from_numpy(channels_first).to(dtype=torch.float32) / 255.0
    image = image.unsqueeze(0)

    rfs = get_resizing_factor((H, W), (h, w))
    resize_size = [int(np.ceil(rfs * side)) for side in (h, w)]
    # Offsets that centre the H x W crop inside the resized image.
    top = (resize_size[0] - H) // 2
    left = (resize_size[1] - W) // 2

    image = torch.nn.functional.interpolate(
        image, resize_size, mode="area", antialias=False
    )
    image = TT.functional.crop(image, top=top, left=left, height=H, width=W)

    if display:
        preview = np.transpose(image[0].numpy(), (1, 2, 0))
        st.image(Image.fromarray((preview * 255).astype(np.uint8)))
    return image.to(device) * 2.0 - 1.0
|
|
|
|
|
|
def save_video_as_grid_and_mp4(
    video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
):
    """Save each video as a PNG frame grid and an MP4, then show it in the UI.

    Args:
        video_batch: Frames of all videos stacked along the first axis in
            ``(b t) c h w`` layout.
        save_path: Output directory; created if missing.
        T: Number of frames per video, used to split the first axis.
        fps: Frame rate passed to the video writer.

    For every video, a ``NNNNNN.png`` grid (4 frames per row) and a
    ``NNNNNN.mp4`` are written. ``ffmpeg`` is then invoked to re-encode the
    MP4 with libx264 into ``NNNNNN_h264.mp4``, which is read back and passed
    to ``st.video``. Numbering starts at the count of ``*.mp4`` files already
    in ``save_path``.
    """
    # Local import: subprocess is not among this module's top-level imports.
    import subprocess

    os.makedirs(save_path, exist_ok=True)
    base_count = len(glob(os.path.join(save_path, "*.mp4")))

    video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
    video_batch = embed_watermark(video_batch)
    for vid in video_batch:
        save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)

        video_path = os.path.join(save_path, f"{base_count:06d}.mp4")

        writer = cv2.VideoWriter(
            video_path,
            cv2.VideoWriter_fourcc(*"MP4V"),
            fps,
            (vid.shape[-1], vid.shape[-2]),
        )

        frames = (
            (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
        )
        for frame in frames:
            # The frames are RGB; convert each one to BGR before writing.
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

        writer.release()

        video_path_h264 = video_path[:-4] + "_h264.mp4"
        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed to ffmpeg verbatim. As before, the
        # exit status is not checked.
        subprocess.run(
            ["ffmpeg", "-i", video_path, "-c:v", "libx264", video_path_h264],
            check=False,
        )

        with open(video_path_h264, "rb") as f:
            video_bytes = f.read()
        st.video(video_bytes)

        base_count += 1
|