SV3D inference code (#300)

* Makes init changes for SV3D

* Small fixes : cond_aug

* Fixes SV3D checkpoint, fixes rembg

* Black formatting

* Adds streamlit demo, fixes simple sample script

* Removes SV3D video_decoder, keeps SV3D image_decoder

* Updates README

* Minor updates

* Remove GSO script

---------

Co-authored-by: Vikram Voleti <vikram@ip-26-0-153-234.us-west-2.compute.internal>
This commit is contained in:
Vikram Voleti
2024-03-18 23:03:02 +05:30
committed by GitHub
parent c51e4e30c2
commit b4b7b644a1
15 changed files with 937 additions and 85 deletions

View File

@@ -1,8 +1,10 @@
import os
import sys
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
from scripts.demo.sv3d_helpers import *
SAVE_PATH = "outputs/demo/vid/"
@@ -87,11 +89,51 @@ VERSION2SPECS = {
"decoding_t": 14,
},
},
"sv3d_u": {
"T": 21,
"H": 576,
"W": 576,
"C": 4,
"f": 8,
"config": "configs/inference/sv3d_u.yaml",
"ckpt": "checkpoints/sv3d_u.safetensors",
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 3,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 50,
"decoding_t": 14,
},
},
"sv3d_p": {
"T": 21,
"H": 576,
"W": 576,
"C": 4,
"f": 8,
"config": "configs/inference/sv3d_p.yaml",
"ckpt": "checkpoints/sv3d_p.safetensors",
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 3,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 50,
"decoding_t": 14,
},
},
}
if __name__ == "__main__":
st.title("Stable Video Diffusion")
st.title("Stable Video Diffusion / SV3D")
version = st.selectbox(
"Model Version",
[k for k in VERSION2SPECS.keys()],
@@ -131,17 +173,42 @@ if __name__ == "__main__":
{},
)
if "fps" not in ukeys:
value_dict["fps"] = 10
value_dict["image_only_indicator"] = 0
if mode == "img2vid":
img = load_img_for_prediction(W, H)
cond_aug = st.number_input(
"Conditioning augmentation:", value=0.02, min_value=0.0
)
if "sv3d" in version:
cond_aug = 1e-5
else:
cond_aug = st.number_input(
"Conditioning augmentation:", value=0.02, min_value=0.0
)
value_dict["cond_frames_without_noise"] = img
value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
value_dict["cond_aug"] = cond_aug
if "sv3d_p" in version:
elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
trajectory = st.selectbox(
"Trajectory",
["same elevation", "dynamic"],
0,
)
if trajectory == "same elevation":
value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
elif trajectory == "dynamic":
azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=elev_deg)
value_dict["polars_rad"] = np.deg2rad(90) - elev_rad
value_dict["azimuths_rad"] = azim_rad
elif "sv3d_u" in version:
elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
seed = st.sidebar.number_input(
"seed", value=23, min_value=0, max_value=int(1e9)
)
@@ -151,6 +218,19 @@ if __name__ == "__main__":
os.path.join(SAVE_PATH, version), init_value=True
)
if "sv3d" in version:
plot_save_path = os.path.join(save_path, "plot_3D.png")
plot_3D(
azim=value_dict["azimuths_rad"],
polar=value_dict["polars_rad"],
save_path=plot_save_path,
dynamic=("sv3d_p" in version),
)
st.image(
plot_save_path,
f"3D camera trajectory",
)
options["num_frames"] = T
sampler, num_rows, num_cols = init_sampling(options=options)