mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-20 06:44:22 +01:00
SV3D inference code (#300)
* Makes init changes for SV3D * Small fixes : cond_aug * Fixes SV3D checkpoint, fixes rembg * Black formatting * Adds streamlit demo, fixes simple sample script * Removes SV3D video_decoder, keeps SV3D image_decoder * Updates README * Minor updates * Remove GSO script --------- Co-authored-by: Vikram Voleti <vikram@ip-26-0-153-234.us-west-2.compute.internal>
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from scripts.demo.streamlit_helpers import *
|
||||
from scripts.demo.sv3d_helpers import *
|
||||
|
||||
SAVE_PATH = "outputs/demo/vid/"
|
||||
|
||||
@@ -87,11 +89,51 @@ VERSION2SPECS = {
|
||||
"decoding_t": 14,
|
||||
},
|
||||
},
|
||||
"sv3d_u": {
|
||||
"T": 21,
|
||||
"H": 576,
|
||||
"W": 576,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"config": "configs/inference/sv3d_u.yaml",
|
||||
"ckpt": "checkpoints/sv3d_u.safetensors",
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 2.5,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 3,
|
||||
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
||||
"num_steps": 50,
|
||||
"decoding_t": 14,
|
||||
},
|
||||
},
|
||||
"sv3d_p": {
|
||||
"T": 21,
|
||||
"H": 576,
|
||||
"W": 576,
|
||||
"C": 4,
|
||||
"f": 8,
|
||||
"config": "configs/inference/sv3d_p.yaml",
|
||||
"ckpt": "checkpoints/sv3d_p.safetensors",
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 2.5,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
"guider": 3,
|
||||
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
|
||||
"num_steps": 50,
|
||||
"decoding_t": 14,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
st.title("Stable Video Diffusion")
|
||||
st.title("Stable Video Diffusion / SV3D")
|
||||
version = st.selectbox(
|
||||
"Model Version",
|
||||
[k for k in VERSION2SPECS.keys()],
|
||||
@@ -131,17 +173,42 @@ if __name__ == "__main__":
|
||||
{},
|
||||
)
|
||||
|
||||
if "fps" not in ukeys:
|
||||
value_dict["fps"] = 10
|
||||
|
||||
value_dict["image_only_indicator"] = 0
|
||||
|
||||
if mode == "img2vid":
|
||||
img = load_img_for_prediction(W, H)
|
||||
cond_aug = st.number_input(
|
||||
"Conditioning augmentation:", value=0.02, min_value=0.0
|
||||
)
|
||||
if "sv3d" in version:
|
||||
cond_aug = 1e-5
|
||||
else:
|
||||
cond_aug = st.number_input(
|
||||
"Conditioning augmentation:", value=0.02, min_value=0.0
|
||||
)
|
||||
value_dict["cond_frames_without_noise"] = img
|
||||
value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
|
||||
value_dict["cond_aug"] = cond_aug
|
||||
|
||||
if "sv3d_p" in version:
|
||||
elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
|
||||
trajectory = st.selectbox(
|
||||
"Trajectory",
|
||||
["same elevation", "dynamic"],
|
||||
0,
|
||||
)
|
||||
if trajectory == "same elevation":
|
||||
value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
|
||||
value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
|
||||
elif trajectory == "dynamic":
|
||||
azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=elev_deg)
|
||||
value_dict["polars_rad"] = np.deg2rad(90) - elev_rad
|
||||
value_dict["azimuths_rad"] = azim_rad
|
||||
elif "sv3d_u" in version:
|
||||
elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
|
||||
value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
|
||||
value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
|
||||
|
||||
seed = st.sidebar.number_input(
|
||||
"seed", value=23, min_value=0, max_value=int(1e9)
|
||||
)
|
||||
@@ -151,6 +218,19 @@ if __name__ == "__main__":
|
||||
os.path.join(SAVE_PATH, version), init_value=True
|
||||
)
|
||||
|
||||
if "sv3d" in version:
|
||||
plot_save_path = os.path.join(save_path, "plot_3D.png")
|
||||
plot_3D(
|
||||
azim=value_dict["azimuths_rad"],
|
||||
polar=value_dict["polars_rad"],
|
||||
save_path=plot_save_path,
|
||||
dynamic=("sv3d_p" in version),
|
||||
)
|
||||
st.image(
|
||||
plot_save_path,
|
||||
f"3D camera trajectory",
|
||||
)
|
||||
|
||||
options["num_frames"] = T
|
||||
|
||||
sampler, num_rows, num_cols = init_sampling(options=options)
|
||||
|
||||
Reference in New Issue
Block a user