mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-19 22:34:22 +01:00)

Stable Video Diffusion
scripts/demo/discretization.py (new file, 59 lines)

@@ -0,0 +1,59 @@
import torch

from sgm.modules.diffusionmodules.discretizer import Discretization


class Img2ImgDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas.
    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization: Discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large and decrease
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img:", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        # compute the cutoff before slicing so the printed index matches the slice taken
        prune_index = max(int(self.strength * len(sigmas)), 1)
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning:", sigmas)
        return sigmas


class Txt2NoisyDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas.
    params:
        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
    """

    def __init__(
        self, discretization: Discretization, strength: float = 0.0, original_steps=None
    ):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        # sigmas start large and decrease
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning txt2noisy:", sigmas)
        sigmas = torch.flip(sigmas, (0,))
        if self.original_steps is None:
            steps = len(sigmas)
        else:
            steps = self.original_steps + 1
        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
        sigmas = sigmas[prune_index:]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning:", sigmas)
        return sigmas
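
For orientation, a minimal usage sketch of the two wrappers. It assumes only that any callable returning a descending sigma schedule can stand in for a Discretization, since the wrappers just call it; the values below are hypothetical:

import torch

from scripts.demo.discretization import (
    Img2ImgDiscretizationWrapper,
    Txt2NoisyDiscretizationWrapper,
)

# Hypothetical stand-in for a Discretization: 10 sigmas, largest first.
base = lambda: torch.linspace(14.6, 0.03, 10)

# strength=0.5 keeps only the 5 smallest sigmas (img2img: start sampling
# partway down the noise schedule from a partially noised input).
print(Img2ImgDiscretizationWrapper(base, strength=0.5)())

# strength=0.5 here instead drops the 4 smallest sigmas (txt2noisy: stop
# early and hand a still-noisy latent to a refiner stage).
print(Txt2NoisyDiscretizationWrapper(base, strength=0.5)())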
scripts/demo/sampling.py

@@ -253,7 +253,10 @@ if __name__ == "__main__":
     st.title("Stable Diffusion")
     version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
     version_dict = VERSION2SPECS[version]
-    mode = st.radio("Mode", ("txt2img", "img2img"), 0)
+    if st.checkbox("Load Model"):
+        mode = st.radio("Mode", ("txt2img", "img2img"), 0)
+    else:
+        mode = "skip"
     st.write("__________________________")
 
     set_lowvram_mode(st.checkbox("Low vram mode", True))
@@ -269,10 +272,11 @@ if __name__ == "__main__":
 
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
 
-    state = init_st(version_dict, load_filter=True)
-    if state["msg"]:
-        st.info(state["msg"])
-    model = state["model"]
+    if mode != "skip":
+        state = init_st(version_dict, load_filter=True)
+        if state["msg"]:
+            st.info(state["msg"])
+        model = state["model"]
 
     is_legacy = version_dict["is_legacy"]
 
@@ -333,6 +337,8 @@ if __name__ == "__main__":
             filter=state.get("filter"),
             stage2strength=stage2strength,
         )
+    elif mode == "skip":
+        out = None
     else:
         raise ValueError(f"unknown mode {mode}")
     if isinstance(out, (tuple, list)):
scripts/demo/streamlit_helpers.py

@@ -1,10 +1,15 @@
+import copy
 import math
 import os
-from typing import List, Union
+from glob import glob
+from typing import Dict, List, Optional, Tuple, Union
 
+import cv2
 import numpy as np
 import streamlit as st
 import torch
+import torch.nn as nn
+import torchvision.transforms as TT
 from einops import rearrange, repeat
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig, OmegaConf
@@ -12,63 +17,22 @@ from PIL import Image
 from safetensors.torch import load_file as load_safetensors
 from torch import autocast
 from torchvision import transforms
-from torchvision.utils import make_grid
+from torchvision.utils import make_grid, save_image
 
-from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
-from sgm.modules.diffusionmodules.sampling import (
-    DPMPP2MSampler,
-    DPMPP2SAncestralSampler,
-    EulerAncestralSampler,
-    EulerEDMSampler,
-    HeunEDMSampler,
-    LinearMultistepSampler,
-)
-from sgm.util import append_dims, instantiate_from_config
-
-
-class WatermarkEmbedder:
-    def __init__(self, watermark):
-        self.watermark = watermark
-        self.num_bits = len(WATERMARK_BITS)
-        self.encoder = WatermarkEncoder()
-        self.encoder.set_watermark("bits", self.watermark)
-
-    def __call__(self, image: torch.Tensor):
-        """
-        Adds a predefined watermark to the input image
-
-        Args:
-            image: ([N,] B, C, H, W) in range [0, 1]
-
-        Returns:
-            same as input but watermarked
-        """
-        # watermarking libary expects input as cv2 BGR format
-        squeeze = len(image.shape) == 4
-        if squeeze:
-            image = image[None, ...]
-        n = image.shape[0]
-        image_np = rearrange(
-            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()[:, :, :, ::-1]
-        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
-        for k in range(image_np.shape[0]):
-            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
-        image = torch.from_numpy(
-            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
-        ).to(image.device)
-        image = torch.clamp(image / 255, min=0.0, max=1.0)
-        if squeeze:
-            image = image[0]
-        return image
-
-
-# A fixed 48-bit message that was choosen at random
-# WATERMARK_MESSAGE = 0xB3EC907BB19E
-WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
-# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
-WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
-embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
+from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
+                                         Txt2NoisyDiscretizationWrapper)
+from scripts.util.detection.nsfw_and_watermark_dectection import \
+    DeepFloydDataFiltering
+from sgm.inference.helpers import embed_watermark
+from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
+                                                  VanillaCFG)
+from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
+                                                   DPMPP2SAncestralSampler,
+                                                   EulerAncestralSampler,
+                                                   EulerEDMSampler,
+                                                   HeunEDMSampler,
+                                                   LinearMultistepSampler)
+from sgm.util import append_dims, default, instantiate_from_config
 
 
 @st.cache_resource()
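
The watermark embedder moves to sgm.inference.helpers as embed_watermark (also retiring the embed_watemark typo); the bit layout is unchanged. As a quick standalone check of the 48-bit expansion (not repo code):

WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
# bin() drops leading zeros, but this message starts with 1, so all 48 bits survive
assert len(WATERMARK_BITS) == 48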
@@ -164,11 +128,12 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
     for key in keys:
         if key == "txt":
             if prompt is None:
-                prompt = st.text_input(
-                    "Prompt", "A professional photograph of an astronaut riding a pig"
-                )
+                prompt = "A professional photograph of an astronaut riding a pig"
             if negative_prompt is None:
-                negative_prompt = st.text_input("Negative prompt", "")
+                negative_prompt = ""
+
+            prompt = st.text_input("Prompt", prompt)
+            negative_prompt = st.text_input("Negative prompt", negative_prompt)
 
             value_dict["prompt"] = prompt
             value_dict["negative_prompt"] = negative_prompt
@@ -203,13 +168,35 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
             value_dict["target_width"] = init_dict["target_width"]
             value_dict["target_height"] = init_dict["target_height"]
 
+        if key in ["fps_id", "fps"]:
+            fps = st.number_input("fps", value=6, min_value=1)
+
+            value_dict["fps"] = fps
+            value_dict["fps_id"] = fps - 1
+
+        if key == "motion_bucket_id":
+            mb_id = st.number_input("motion bucket id", 0, 511, value=127)
+            value_dict["motion_bucket_id"] = mb_id
+
         if key == "pool_image":
             st.text("Image for pool conditioning")
             image = load_img(
                 key="pool_image_input",
+                size=224,
+                center_crop=True,
             )
             if image is None:
                 st.info("Need an image here")
                 image = torch.zeros(1, 3, 224, 224)
             value_dict["pool_image"] = image
 
     return value_dict
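
A note on the new conditioning widgets above: the model is conditioned on fps_id = fps - 1 (so the default 6 fps becomes fps_id 5), and motion_bucket_id is an integer in [0, 511] controlling how much motion is requested, with 127 as this demo's default. A trivial check of the mapping, with hypothetical values:

fps = 6
value_dict = {"fps": fps, "fps_id": fps - 1, "motion_bucket_id": 127}
assert value_dict["fps_id"] == 5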
 def perform_save_locally(save_path, samples):
     os.makedirs(os.path.join(save_path), exist_ok=True)
     base_count = len(os.listdir(os.path.join(save_path)))
-    samples = embed_watemark(samples)
+    samples = embed_watermark(samples)
     for sample in samples:
         sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
         Image.fromarray(sample.astype(np.uint8)).save(
@@ -228,95 +215,99 @@ def init_save_locally(_dir, init_value: bool = False):
     return save_locally, save_path
 
 
-class Img2ImgDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 1.0):
-        self.discretization = discretization
-        self.strength = strength
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
-        print("prune index:", max(int(self.strength * len(sigmas)), 1))
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
-
-
-class Txt2NoisyDiscretizationWrapper:
-    """
-    wraps a discretizer, and prunes the sigmas
-    params:
-        strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
-    """
-
-    def __init__(self, discretization, strength: float = 0.0, original_steps=None):
-        self.discretization = discretization
-        self.strength = strength
-        self.original_steps = original_steps
-        assert 0.0 <= self.strength <= 1.0
-
-    def __call__(self, *args, **kwargs):
-        # sigmas start large first, and decrease then
-        sigmas = self.discretization(*args, **kwargs)
-        print(f"sigmas after discretization, before pruning img2img: ", sigmas)
-        sigmas = torch.flip(sigmas, (0,))
-        if self.original_steps is None:
-            steps = len(sigmas)
-        else:
-            steps = self.original_steps + 1
-        prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
-        sigmas = sigmas[prune_index:]
-        print("prune index:", prune_index)
-        sigmas = torch.flip(sigmas, (0,))
-        print(f"sigmas after pruning: ", sigmas)
-        return sigmas
-
-
-def get_guider(key):
+def get_guider(options, key):
     guider = st.sidebar.selectbox(
         f"Discretization #{key}",
         [
             "VanillaCFG",
             "IdentityGuider",
+            "LinearPredictionGuider",
         ],
+        options.get("guider", 0),
     )
 
+    additional_guider_kwargs = options.pop("additional_guider_kwargs", {})
+
     if guider == "IdentityGuider":
         guider_config = {
             "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
         }
     elif guider == "VanillaCFG":
-        scale = st.number_input(
-            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
+        scale_schedule = st.sidebar.selectbox(
+            f"Scale schedule #{key}",
+            ["Identity", "Oscillating"],
         )
 
-        thresholder = st.sidebar.selectbox(
-            f"Thresholder #{key}",
-            [
-                "None",
-            ],
-        )
+        if scale_schedule == "Identity":
+            scale = st.number_input(
+                f"cfg-scale #{key}",
+                value=options.get("cfg", 5.0),
+                min_value=0.0,
+            )
 
-        if thresholder == "None":
-            dyn_thresh_config = {
-                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
+            scale_schedule_config = {
+                "target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
+                "params": {"scale": scale},
             }
 
+        elif scale_schedule == "Oscillating":
+            small_scale = st.number_input(
+                f"small cfg-scale #{key}",
+                value=4.0,
+                min_value=0.0,
+            )
+
+            large_scale = st.number_input(
+                f"large cfg-scale #{key}",
+                value=16.0,
+                min_value=0.0,
+            )
+
+            sigma_cutoff = st.number_input(
+                f"sigma cutoff #{key}",
+                value=1.0,
+                min_value=0.0,
+            )
+
+            scale_schedule_config = {
+                "target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
+                "params": {
+                    "small_scale": small_scale,
+                    "large_scale": large_scale,
+                    "sigma_cutoff": sigma_cutoff,
+                },
+            }
         else:
             raise NotImplementedError
 
         guider_config = {
             "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
-            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
+            "params": {
+                "scale_schedule_config": scale_schedule_config,
+                **additional_guider_kwargs,
+            },
         }
+    elif guider == "LinearPredictionGuider":
+        max_scale = st.number_input(
+            f"max-cfg-scale #{key}",
+            value=options.get("cfg", 1.5),
+            min_value=1.0,
+        )
+        min_scale = st.number_input(
+            f"min guidance scale",
+            value=options.get("min_cfg", 1.0),
+            min_value=1.0,
+            max_value=10.0,
+        )
+
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
+            "params": {
+                "max_scale": max_scale,
+                "min_scale": min_scale,
+                "num_frames": options["num_frames"],
+                **additional_guider_kwargs,
+            },
+        }
     else:
         raise NotImplementedError
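
The LinearPredictionGuider branch is the video-relevant addition here: rather than a single cfg scale, guidance strength is interpolated per frame between min_scale and max_scale. A minimal sketch of that per-frame ramp (illustrative only; the real guider lives in sgm.modules.diffusionmodules.guiders):

import torch

def linear_cfg_ramp(min_scale: float, max_scale: float, num_frames: int) -> torch.Tensor:
    # One guidance scale per frame, rising linearly from min_scale on the
    # first frame to max_scale on the last.
    return torch.linspace(min_scale, max_scale, num_frames)

# With the "svd_xt" defaults below (min_cfg=1.5, cfg=3.0, T=25):
print(linear_cfg_ramp(1.5, 3.0, 25))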
@@ -325,18 +316,21 @@ def get_guider(key):
 
 def init_sampling(
     key=1,
-    img2img_strength=1.0,
-    specify_num_samples=True,
-    stage2strength=None,
+    img2img_strength: Optional[float] = None,
+    specify_num_samples: bool = True,
+    stage2strength: Optional[float] = None,
+    options: Optional[Dict[str, int]] = None,
 ):
+    options = {} if options is None else options
+
     num_rows, num_cols = 1, 1
     if specify_num_samples:
         num_cols = st.number_input(
-            f"num cols #{key}", value=2, min_value=1, max_value=10
+            f"num cols #{key}", value=num_cols, min_value=1, max_value=10
         )
 
     steps = st.sidebar.number_input(
-        f"steps #{key}", value=40, min_value=1, max_value=1000
+        f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
     )
     sampler = st.sidebar.selectbox(
         f"Sampler #{key}",
@@ -348,7 +342,7 @@ def init_sampling(
             "DPMPP2MSampler",
             "LinearMultistepSampler",
         ],
-        0,
+        options.get("sampler", 0),
     )
     discretization = st.sidebar.selectbox(
         f"Discretization #{key}",
@@ -356,14 +350,15 @@ def init_sampling(
             "LegacyDDPMDiscretization",
             "EDMDiscretization",
         ],
+        options.get("discretization", 0),
     )
 
-    discretization_config = get_discretization(discretization, key=key)
+    discretization_config = get_discretization(discretization, options=options, key=key)
 
-    guider_config = get_guider(key=key)
+    guider_config = get_guider(options=options, key=key)
 
     sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
-    if img2img_strength < 1.0:
+    if img2img_strength is not None:
         st.warning(
             f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
         )
@@ -377,15 +372,19 @@ def init_sampling(
     return sampler, num_rows, num_cols
 
 
-def get_discretization(discretization, key=1):
+def get_discretization(discretization, options, key=1):
     if discretization == "LegacyDDPMDiscretization":
         discretization_config = {
             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
         }
     elif discretization == "EDMDiscretization":
-        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
-        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
-        rho = st.number_input(f"rho #{key}", value=3.0)
+        sigma_min = st.number_input(
+            f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
+        )  # 0.0292
+        sigma_max = st.number_input(
+            f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
+        )  # 14.6146
+        rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
         discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
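
The EDM defaults are now read from options; the SVD specs further down set sigma_min=0.002, sigma_max=700.0, rho=7.0. For reference, these parameters feed the Karras et al. (2022) noise schedule, sketched here (the repo's EDMDiscretization implements the same formula):

import torch

def edm_sigmas(n: int, sigma_min: float, sigma_max: float, rho: float) -> torch.Tensor:
    # sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    ramp = torch.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(edm_sigmas(25, 0.002, 700.0, 7.0))  # 700.0 down to 0.002 over 25 steps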
@@ -474,8 +473,8 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
     return sampler
 
 
-def get_interactive_image(key=None) -> Image.Image:
-    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
+def get_interactive_image() -> Image.Image:
+    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"])
     if image is not None:
         image = Image.open(image)
         if not image.mode == "RGB":
@@ -483,8 +482,12 @@ def get_interactive_image(key=None) -> Image.Image:
         return image
 
 
-def load_img(display=True, key=None):
-    image = get_interactive_image(key=key)
+def load_img(
+    display: bool = True,
+    size: Union[None, int, Tuple[int, int]] = None,
+    center_crop: bool = False,
+):
+    image = get_interactive_image()
     if image is None:
         return None
     if display:
@@ -492,12 +495,15 @@ def load_img(display=True, key=None):
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
 
-    transform = transforms.Compose(
-        [
-            transforms.ToTensor(),
-            transforms.Lambda(lambda x: x * 2.0 - 1.0),
-        ]
-    )
+    transform = []
+    if size is not None:
+        transform.append(transforms.Resize(size))
+    if center_crop:
+        transform.append(transforms.CenterCrop(size))
+    transform.append(transforms.ToTensor())
+    transform.append(transforms.Lambda(lambda x: 2.0 * x - 1.0))
+
+    transform = transforms.Compose(transform)
     img = transform(image)[None, ...]
     st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
     return img
@@ -518,15 +524,18 @@ def do_sample(
     W,
     C,
     F,
-    force_uc_zero_embeddings: List = None,
+    force_uc_zero_embeddings: Optional[List] = None,
+    force_cond_zero_embeddings: Optional[List] = None,
     batch2model_input: List = None,
     return_latents=False,
     filter=None,
+    T=None,
+    additional_batch_uc_fields=None,
+    decoding_t=None,
 ):
-    if force_uc_zero_embeddings is None:
-        force_uc_zero_embeddings = []
-    if batch2model_input is None:
-        batch2model_input = []
+    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
+    batch2model_input = default(batch2model_input, [])
+    additional_batch_uc_fields = default(additional_batch_uc_fields, [])
 
     st.text("Sampling")
 
@@ -535,24 +544,25 @@ def do_sample(
     with torch.no_grad():
         with precision_scope("cuda"):
             with model.ema_scope():
-                num_samples = [num_samples]
+                if T is not None:
+                    num_samples = [num_samples, T]
+                else:
+                    num_samples = [num_samples]
+
                 load_model(model.conditioner)
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
                     value_dict,
                     num_samples,
+                    T=T,
+                    additional_batch_uc_fields=additional_batch_uc_fields,
                 )
                 for key in batch:
                     if isinstance(batch[key], torch.Tensor):
                         print(key, batch[key].shape)
                     elif isinstance(batch[key], list):
                         print(key, [len(l) for l in batch[key]])
                     else:
                         print(key, batch[key])
 
                 c, uc = model.conditioner.get_unconditional_conditioning(
                     batch,
                     batch_uc=batch_uc,
                     force_uc_zero_embeddings=force_uc_zero_embeddings,
+                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                 )
                 unload_model(model.conditioner)
 
@@ -561,10 +571,29 @@ def do_sample(
                         c[k], uc[k] = map(
                             lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                         )
+                    if k in ["crossattn", "concat"] and T is not None:
+                        uc[k] = repeat(uc[k], "b ... -> b t ...", t=T)
+                        uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=T)
+                        c[k] = repeat(c[k], "b ... -> b t ...", t=T)
+                        c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=T)
 
                 additional_model_inputs = {}
                 for k in batch2model_input:
-                    additional_model_inputs[k] = batch[k]
+                    if k == "image_only_indicator":
+                        assert T is not None
+
+                        if isinstance(
+                            sampler.guider, (VanillaCFG, LinearPredictionGuider)
+                        ):
+                            additional_model_inputs[k] = torch.zeros(
+                                num_samples[0] * 2, num_samples[1]
+                            ).to("cuda")
+                        else:
+                            additional_model_inputs[k] = torch.zeros(num_samples).to(
+                                "cuda"
+                            )
+                    else:
+                        additional_model_inputs[k] = batch[k]
 
                 shape = (math.prod(num_samples), C, H // F, W // F)
                 randn = torch.randn(shape).to("cuda")
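
Why num_samples[0] * 2 in the zeros above: CFG-style guiders run the conditional and unconditional branches as one concatenated batch, so the per-frame image_only_indicator must cover both halves. A shape check under assumed values:

import torch

num_samples = [2, 14]  # [batch, T], as set earlier when T is not None
indicator = torch.zeros(num_samples[0] * 2, num_samples[1])  # cond + uncond
print(indicator.shape)  # torch.Size([4, 14])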
@@ -581,6 +610,9 @@ def do_sample(
                 unload_model(model.denoiser)
 
                 load_model(model.first_stage_model)
+                model.en_and_decode_n_samples_a_time = (
+                    decoding_t  # Decode n frames at a time
+                )
                 samples_x = model.decode_first_stage(samples_z)
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
                 unload_model(model.first_stage_model)
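
A note on the new decoder knob: with video models the latents reach the first stage as one (num_samples * T, C, H // F, W // F) batch, and setting en_and_decode_n_samples_a_time = decoding_t makes decode_first_stage work through that batch decoding_t frames at a time to bound VRAM. Under assumed svd_xt-style numbers, a single sample of 25 frames at 576x1024 gives a (25, 4, 72, 128) latent batch that is decoded in chunks of 14 by default.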
@@ -588,16 +620,32 @@ def do_sample(
                 if filter is not None:
                     samples = filter(samples)
 
-                grid = torch.stack([samples])
-                grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
-                outputs.image(grid.cpu().numpy())
+                if T is None:
+                    grid = torch.stack([samples])
+                    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
+                    outputs.image(grid.cpu().numpy())
+                else:
+                    as_vids = rearrange(samples, "(b t) c h w -> b t c h w", t=T)
+                    for i, vid in enumerate(as_vids):
+                        grid = rearrange(make_grid(vid, nrow=4), "c h w -> h w c")
+                        st.image(
+                            grid.cpu().numpy(),
+                            f"Sample #{i} as image",
+                        )
 
                 if return_latents:
                     return samples, samples_z
                 return samples
 
 
-def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
+def get_batch(
+    keys,
+    value_dict: dict,
+    N: Union[List, ListConfig],
+    device: str = "cuda",
+    T: int = None,
+    additional_batch_uc_fields: List[str] = [],
+):
     # Hardcoded demo setups; might undergo some changes in the future
 
     batch = {}
@@ -605,21 +653,15 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
     for key in keys:
         if key == "txt":
-            batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
-            )
-            batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
-            )
+            batch["txt"] = [value_dict["prompt"]] * math.prod(N)
+
+            batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
+
         elif key == "original_size_as_tuple":
             batch["original_size_as_tuple"] = (
                 torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
                 .to(device)
-                .repeat(*N, 1)
+                .repeat(math.prod(N), 1)
             )
         elif key == "crop_coords_top_left":
             batch["crop_coords_top_left"] = (
@@ -627,30 +669,67 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
                     [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
                 )
                 .to(device)
-                .repeat(*N, 1)
+                .repeat(math.prod(N), 1)
             )
         elif key == "aesthetic_score":
             batch["aesthetic_score"] = (
-                torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
+                torch.tensor([value_dict["aesthetic_score"]])
+                .to(device)
+                .repeat(math.prod(N), 1)
             )
             batch_uc["aesthetic_score"] = (
                 torch.tensor([value_dict["negative_aesthetic_score"]])
                 .to(device)
-                .repeat(*N, 1)
+                .repeat(math.prod(N), 1)
             )
 
         elif key == "target_size_as_tuple":
             batch["target_size_as_tuple"] = (
                 torch.tensor([value_dict["target_height"], value_dict["target_width"]])
                 .to(device)
-                .repeat(*N, 1)
+                .repeat(math.prod(N), 1)
             )
+        elif key == "fps":
+            batch[key] = (
+                torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
+            )
+        elif key == "fps_id":
+            batch[key] = (
+                torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
+            )
+        elif key == "motion_bucket_id":
+            batch[key] = (
+                torch.tensor([value_dict["motion_bucket_id"]])
+                .to(device)
+                .repeat(math.prod(N))
+            )
+        elif key == "pool_image":
+            batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
+                device, dtype=torch.half
+            )
+        elif key == "cond_aug":
+            batch[key] = repeat(
+                torch.tensor([value_dict["cond_aug"]]).to("cuda"),
+                "1 -> b",
+                b=math.prod(N),
+            )
+        elif key == "cond_frames":
+            batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
+        elif key == "cond_frames_without_noise":
+            batch[key] = repeat(
+                value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
+            )
         else:
             batch[key] = value_dict[key]
 
+    if T is not None:
+        batch["num_video_frames"] = T
+
     for key in batch.keys():
         if key not in batch_uc and isinstance(batch[key], torch.Tensor):
             batch_uc[key] = torch.clone(batch[key])
+        elif key in additional_batch_uc_fields and key not in batch_uc:
+            batch_uc[key] = copy.copy(batch[key])
     return batch, batch_uc
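
To see how the new video keys batch out, a small shape walkthrough under assumed values (N = [num_samples, T], as do_sample sets it for video models):

import math

import torch
from einops import repeat

N = [2, 14]  # 2 samples of 14 frames each
fps_id = torch.tensor([5]).repeat(math.prod(N))  # per-frame scalar -> shape (28,)

cond_frames = torch.randn(1, 3, 576, 1024)  # one conditioning image (hypothetical size)
frames = repeat(cond_frames, "1 ... -> b ...", b=N[0])  # per sample -> (2, 3, 576, 1024)
print(fps_id.shape, frames.shape)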
@@ -661,7 +740,8 @@ def do_img2img(
     sampler,
     value_dict,
     num_samples,
-    force_uc_zero_embeddings=[],
+    force_uc_zero_embeddings: Optional[List] = None,
+    force_cond_zero_embeddings: Optional[List] = None,
     additional_kwargs={},
     offset_noise_level: int = 0.0,
     return_latents=False,
@@ -686,6 +766,7 @@ def do_img2img(
                     batch,
                     batch_uc=batch_uc,
                     force_uc_zero_embeddings=force_uc_zero_embeddings,
+                    force_cond_zero_embeddings=force_cond_zero_embeddings,
                 )
                 unload_model(model.conditioner)
                 for k in c:
@@ -736,9 +817,112 @@ def do_img2img(
                 if filter is not None:
                     samples = filter(samples)
 
                 grid = embed_watemark(torch.stack([samples]))
                 grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
                 outputs.image(grid.cpu().numpy())
                 if return_latents:
                     return samples, samples_z
                 return samples
+
+
+def get_resizing_factor(
+    desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
+) -> float:
+    r_bound = desired_shape[1] / desired_shape[0]
+    aspect_r = current_shape[1] / current_shape[0]
+    if r_bound >= 1.0:
+        if aspect_r >= r_bound:
+            factor = min(desired_shape) / min(current_shape)
+        else:
+            if aspect_r < 1.0:
+                factor = max(desired_shape) / min(current_shape)
+            else:
+                factor = max(desired_shape) / max(current_shape)
+    else:
+        if aspect_r <= r_bound:
+            factor = min(desired_shape) / min(current_shape)
+        else:
+            if aspect_r > 1:
+                factor = max(desired_shape) / min(current_shape)
+            else:
+                factor = max(desired_shape) / max(current_shape)
+
+    return factor
+
+
+def get_interactive_image(key=None) -> Image.Image:
+    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
+    if image is not None:
+        image = Image.open(image)
+        if not image.mode == "RGB":
+            image = image.convert("RGB")
+    return image
+
+
+def load_img_for_prediction(
+    W: int, H: int, display=True, key=None, device="cuda"
+) -> torch.Tensor:
+    image = get_interactive_image(key=key)
+    if image is None:
+        return None
+    if display:
+        st.image(image)
+    w, h = image.size
+
+    image = np.array(image).transpose(2, 0, 1)
+    image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
+    image = image.unsqueeze(0)
+
+    rfs = get_resizing_factor((H, W), (h, w))
+    resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
+    top = (resize_size[0] - H) // 2
+    left = (resize_size[1] - W) // 2
+
+    image = torch.nn.functional.interpolate(
+        image, resize_size, mode="area", antialias=False
+    )
+    image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
+
+    if display:
+        numpy_img = np.transpose(image[0].numpy(), (1, 2, 0))
+        pil_image = Image.fromarray((numpy_img * 255).astype(np.uint8))
+        st.image(pil_image)
+    return image.to(device) * 2.0 - 1.0
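
A worked check of the branchy get_resizing_factor above (hypothetical sizes): it returns the smallest scale that still lets an (H, W) crop be cut from the resized image, i.e. a cover-resize.

# target (H, W) = (576, 1024), input (h, w) = (720, 1280):
#   r_bound = 1024/576 ~ 1.78, aspect_r = 1280/720 ~ 1.78 >= r_bound
#   -> factor = 576/720 = 0.8; resize to (576, 1024); the crop is a no-op
print(get_resizing_factor((576, 1024), (720, 1280)))  # 0.8

# square input (h, w) = (1080, 1080):
#   -> factor = 1024/1080 ~ 0.948; resize to (1024, 1024);
#      then center-crop 576x1024 with top = 224, left = 0
print(get_resizing_factor((576, 1024), (1080, 1080)))  # ~0.948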
+
+
+def save_video_as_grid_and_mp4(
+    video_batch: torch.Tensor, save_path: str, T: int, fps: int = 5
+):
+    os.makedirs(save_path, exist_ok=True)
+    base_count = len(glob(os.path.join(save_path, "*.mp4")))
+
+    video_batch = rearrange(video_batch, "(b t) c h w -> b t c h w", t=T)
+    video_batch = embed_watermark(video_batch)
+    for vid in video_batch:
+        save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
+
+        video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
+
+        writer = cv2.VideoWriter(
+            video_path,
+            cv2.VideoWriter_fourcc(*"MP4V"),
+            fps,
+            (vid.shape[-1], vid.shape[-2]),
+        )
+
+        vid = (
+            (rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
+        )
+        for frame in vid:
+            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
+            writer.write(frame)
+
+        writer.release()
+
+        video_path_h264 = video_path[:-4] + "_h264.mp4"
+        os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
+
+        with open(video_path_h264, "rb") as f:
+            video_bytes = f.read()
+        st.video(video_bytes)
+
+        base_count += 1
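
On the ffmpeg round trip at the end: cv2.VideoWriter with the "MP4V" fourcc emits MPEG-4 Part 2, which st.video (and browsers generally) cannot play, so the file is re-encoded to H.264 via libx264 before being displayed. Note the re-encode shells out, so ffmpeg must be on PATH for the preview to work.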

scripts/demo/video_sampling.py (new file, 200 lines)

@@ -0,0 +1,200 @@
import os

from pytorch_lightning import seed_everything

from scripts.demo.streamlit_helpers import *

SAVE_PATH = "outputs/demo/vid/"

VERSION2SPECS = {
    "svd": {
        "T": 14,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd.yaml",
        "ckpt": "checkpoints/svd.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 25,
        },
    },
    "svd_image_decoder": {
        "T": 14,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd_image_decoder.yaml",
        "ckpt": "checkpoints/svd_image_decoder.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 25,
        },
    },
    "svd_xt": {
        "T": 25,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd.yaml",
        "ckpt": "checkpoints/svd_xt.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 3.0,
            "min_cfg": 1.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 30,
            "decoding_t": 14,
        },
    },
    "svd_xt_image_decoder": {
        "T": 25,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd_image_decoder.yaml",
        "ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 3.0,
            "min_cfg": 1.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 30,
            "decoding_t": 14,
        },
    },
}
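
For readers decoding the option dicts above: the integers are default indices into the selectbox lists in streamlit_helpers.py, and the other entries are widget defaults. Roughly:

# How the "svd" options map onto the sidebar defaults (illustrative):
#   "discretization": 1   -> "EDMDiscretization" (index 1 of the discretization list)
#   "guider": 2           -> "LinearPredictionGuider" (index 2 of the guider list)
#   "cfg" / "min_cfg"     -> defaults for the guider's max_scale / min_scale
#   "sigma_min/max", "rho" -> EDMDiscretization defaults
#   "num_steps"           -> default for the steps number_input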
if __name__ == "__main__":
    st.title("Stable Video Diffusion")
    version = st.selectbox(
        "Model Version",
        [k for k in VERSION2SPECS.keys()],
        0,
    )
    version_dict = VERSION2SPECS[version]
    if st.checkbox("Load Model"):
        mode = "img2vid"
    else:
        mode = "skip"

    H = st.sidebar.number_input(
        "H", value=version_dict["H"], min_value=64, max_value=2048
    )
    W = st.sidebar.number_input(
        "W", value=version_dict["W"], min_value=64, max_value=2048
    )
    T = st.sidebar.number_input(
        "T", value=version_dict["T"], min_value=0, max_value=128
    )
    C = version_dict["C"]
    F = version_dict["f"]
    options = version_dict["options"]

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]

        ukeys = set(
            get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
        )

        value_dict = init_embedder_options(
            ukeys,
            {},
        )

        value_dict["image_only_indicator"] = 0

        if mode == "img2vid":
            img = load_img_for_prediction(W, H)
            cond_aug = st.number_input(
                "Conditioning augmentation:", value=0.02, min_value=0.0
            )
            value_dict["cond_frames_without_noise"] = img
            value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
            value_dict["cond_aug"] = cond_aug

        seed = st.sidebar.number_input(
            "seed", value=23, min_value=0, max_value=int(1e9)
        )
        seed_everything(seed)

        save_locally, save_path = init_save_locally(
            os.path.join(SAVE_PATH, version), init_value=True
        )

        options["num_frames"] = T

        sampler, num_rows, num_cols = init_sampling(options=options)
        num_samples = num_rows * num_cols

        decoding_t = st.number_input(
            "Decode t frames at a time (set small if you are low on VRAM)",
            value=options.get("decoding_t", T),
            min_value=1,
            max_value=int(1e9),
        )

        if st.checkbox("Overwrite fps in mp4 generator", False):
            saving_fps = st.number_input(
                "saving video at fps:", value=value_dict["fps"], min_value=1
            )
        else:
            saving_fps = value_dict["fps"]

        if st.button("Sample"):
            out = do_sample(
                model,
                sampler,
                value_dict,
                num_samples,
                H,
                W,
                C,
                F,
                T=T,
                batch2model_input=["num_video_frames", "image_only_indicator"],
                force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
                force_cond_zero_embeddings=options.get(
                    "force_cond_zero_embeddings", None
                ),
                return_latents=False,
                decoding_t=decoding_t,
            )

            if isinstance(out, (tuple, list)):
                samples, samples_z = out
            else:
                samples = out
                samples_z = None

            if save_locally:
                save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)