mirror of https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 22:34:22 +01:00
Stable Video Diffusion
scripts/demo/video_sampling.py (new file, 200 lines)

@@ -0,0 +1,200 @@
import os

from pytorch_lightning import seed_everything

# streamlit_helpers re-exports streamlit as `st` and torch, along with the
# demo helper functions used below (init_st, do_sample, ...).
from scripts.demo.streamlit_helpers import *

SAVE_PATH = "outputs/demo/vid/"
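
# Per-version specs: T is the default number of video frames, H/W the output
# resolution, C the latent channel count, and f the spatial downsampling
# factor of the autoencoder.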
VERSION2SPECS = {
    "svd": {
        "T": 14,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd.yaml",
        "ckpt": "checkpoints/svd.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 25,
        },
    },
    "svd_image_decoder": {
        "T": 14,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd_image_decoder.yaml",
        "ckpt": "checkpoints/svd_image_decoder.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 2.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 25,
        },
    },
    "svd_xt": {
        "T": 25,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd.yaml",
        "ckpt": "checkpoints/svd_xt.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 3.0,
            "min_cfg": 1.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 30,
            "decoding_t": 14,
        },
    },
    "svd_xt_image_decoder": {
        "T": 25,
        "H": 576,
        "W": 1024,
        "C": 4,
        "f": 8,
        "config": "configs/inference/svd_image_decoder.yaml",
        "ckpt": "checkpoints/svd_xt_image_decoder.safetensors",
        "options": {
            "discretization": 1,
            "cfg": 3.0,
            "min_cfg": 1.5,
            "sigma_min": 0.002,
            "sigma_max": 700.0,
            "rho": 7.0,
            "guider": 2,
            "force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
            "num_steps": 30,
            "decoding_t": 14,
        },
    },
}
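
# A minimal sketch (illustrative, not part of the demo) of how a spec maps to
# the latent video the sampler denoises, using only the keys defined above:
#
#   spec = VERSION2SPECS["svd"]
#   latent_shape = (spec["T"], spec["C"], spec["H"] // spec["f"], spec["W"] // spec["f"])
#   # -> (14, 4, 72, 128): 14 frames of 4-channel latents at 72 x 128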


if __name__ == "__main__":
    st.title("Stable Video Diffusion")
    version = st.selectbox(
        "Model Version",
        [k for k in VERSION2SPECS.keys()],
        0,
    )
    version_dict = VERSION2SPECS[version]
    # "skip" keeps the UI browsable without loading the large checkpoint.
    if st.checkbox("Load Model"):
        mode = "img2vid"
    else:
        mode = "skip"

    H = st.sidebar.number_input(
        "H", value=version_dict["H"], min_value=64, max_value=2048
    )
    W = st.sidebar.number_input(
        "W", value=version_dict["W"], min_value=64, max_value=2048
    )
    T = st.sidebar.number_input(
        "T", value=version_dict["T"], min_value=0, max_value=128
    )
    C = version_dict["C"]
    F = version_dict["f"]
    options = version_dict["options"]

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]

        ukeys = set(
            get_unique_embedder_keys_from_conditioner(state["model"].conditioner)
        )

        # Prompts for the conditioner's embedder inputs (for SVD this includes
        # fps_id and motion_bucket_id) and also populates value_dict["fps"],
        # which is reused when saving the mp4 below.
        value_dict = init_embedder_options(
            ukeys,
            {},
        )

        # 0 tells the temporal layers to treat every frame as part of a video
        # rather than as a standalone image.
        value_dict["image_only_indicator"] = 0

        if mode == "img2vid":
            img = load_img_for_prediction(W, H)
            cond_aug = st.number_input(
                "Conditioning augmentation:", value=0.02, min_value=0.0
            )
            # The conditioning frame is passed twice: clean for the CLIP image
            # embedder, and with cond_aug-scaled noise added for the concat branch.
            value_dict["cond_frames_without_noise"] = img
            value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
            value_dict["cond_aug"] = cond_aug

        seed = st.sidebar.number_input(
            "seed", value=23, min_value=0, max_value=int(1e9)
        )
        seed_everything(seed)

        save_locally, save_path = init_save_locally(
            os.path.join(SAVE_PATH, version), init_value=True
        )

        options["num_frames"] = T

        sampler, num_rows, num_cols = init_sampling(options=options)
        num_samples = num_rows * num_cols

        # Number of latent frames the VAE decodes per pass; lower values
        # reduce peak VRAM at the cost of speed.
        decoding_t = st.number_input(
            "Decode t frames at a time (set small if you are low on VRAM)",
            value=options.get("decoding_t", T),
            min_value=1,
            max_value=int(1e9),
        )

        if st.checkbox("Overwrite fps in mp4 generator", False):
            saving_fps = st.number_input(
                "saving video at fps:", value=value_dict["fps"], min_value=1
            )
        else:
            saving_fps = value_dict["fps"]

        if st.button("Sample"):
            out = do_sample(
                model,
                sampler,
                value_dict,
                num_samples,
                H,
                W,
                C,
                F,
                T=T,
                batch2model_input=["num_video_frames", "image_only_indicator"],
                force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
                force_cond_zero_embeddings=options.get(
                    "force_cond_zero_embeddings", None
                ),
                return_latents=False,
                decoding_t=decoding_t,
            )

            # do_sample returns (samples, latents) when latents are requested,
            # otherwise just the decoded samples.
            if isinstance(out, (tuple, list)):
                samples, samples_z = out
            else:
                samples = out
                samples_z = None

            if save_locally:
                save_video_as_grid_and_mp4(samples, save_path, T, fps=saving_fps)
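
As with the other demo scripts in this repository, the file is launched through Streamlit from the repo root, assuming the checkpoints referenced in VERSION2SPECS have been downloaded to checkpoints/:

streamlit run scripts/demo/video_sampling.py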