SV3D inference code (#300)

* Makes init changes for SV3D

* Small fixes: cond_aug

* Fixes SV3D checkpoint, fixes rembg

* Black formatting

* Adds streamlit demo, fixes simple sample script

* Removes SV3D video_decoder, keeps SV3D image_decoder

* Updates README

* Minor updates

* Removes GSO script

---------

Co-authored-by: Vikram Voleti <vikram@ip-26-0-153-234.us-west-2.compute.internal>
Committed by: Vikram Voleti (via GitHub)
Date: 2024-03-18 23:03:02 +05:30
Parent: c51e4e30c2
Commit: b4b7b644a1
15 changed files with 937 additions and 85 deletions
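
As a quick orientation, the new SV3D sampling path added here can also be driven from Python; a minimal usage sketch, assuming the sv3d_u checkpoint has been placed at checkpoints/sv3d_u.safetensors as the config expects:

from scripts.sampling.simple_video_sample import sample

# Render a 21-frame orbit of the object in the input image with SV3D_u.
# For the sv3d_* versions the script fixes num_frames to 21 and defaults
# num_steps to 50; elevations_deg is one of the parameters added by this PR.
sample(
    input_path="assets/test_image.png",
    version="sv3d_u",
    elevations_deg=10.0,
)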

View File

@@ -23,9 +23,11 @@ from PIL import Image
from torchvision.transforms import ToTensor
from scripts.sampling.simple_video_sample import (
- get_batch, get_unique_embedder_keys_from_conditioner, load_model)
- from scripts.util.detection.nsfw_and_watermark_dectection import \
- DeepFloydDataFiltering
get_batch,
get_unique_embedder_keys_from_conditioner,
load_model,
)
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config

View File: scripts/demo/streamlit_helpers.py

@@ -5,6 +5,7 @@ from glob import glob
from typing import Dict, List, Optional, Tuple, Union
import cv2
import imageio
import numpy as np
import streamlit as st
import torch
@@ -15,25 +16,30 @@ from imwatermark import WatermarkEncoder
from omegaconf import ListConfig, OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from scripts.demo.discretization import (
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
)
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.modules.diffusionmodules.guiders import (
LinearPredictionGuider,
TrianglePredictionGuider,
VanillaCFG,
)
from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler,
)
from sgm.util import append_dims, default, instantiate_from_config
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid, save_image
- from scripts.demo.discretization import (Img2ImgDiscretizationWrapper,
- Txt2NoisyDiscretizationWrapper)
- from scripts.util.detection.nsfw_and_watermark_dectection import \
- DeepFloydDataFiltering
- from sgm.inference.helpers import embed_watermark
- from sgm.modules.diffusionmodules.guiders import (LinearPredictionGuider,
- VanillaCFG)
- from sgm.modules.diffusionmodules.sampling import (DPMPP2MSampler,
- DPMPP2SAncestralSampler,
- EulerAncestralSampler,
- EulerEDMSampler,
- HeunEDMSampler,
- LinearMultistepSampler)
- from sgm.util import append_dims, default, instantiate_from_config
@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
@@ -222,6 +228,7 @@ def get_guider(options, key):
"VanillaCFG",
"IdentityGuider",
"LinearPredictionGuider",
"TrianglePredictionGuider",
],
options.get("guider", 0),
)
@@ -252,7 +259,7 @@ def get_guider(options, key):
value=options.get("cfg", 1.5),
min_value=1.0,
)
- min_scale = st.number_input(
min_scale = st.sidebar.number_input(
f"min guidance scale",
value=options.get("min_cfg", 1.0),
min_value=1.0,
@@ -268,6 +275,29 @@ def get_guider(options, key):
**additional_guider_kwargs,
},
}
elif guider == "TrianglePredictionGuider":
max_scale = st.number_input(
f"max-cfg-scale #{key}",
value=options.get("cfg", 2.5),
min_value=1.0,
max_value=10.0,
)
min_scale = st.sidebar.number_input(
f"min guidance scale",
value=options.get("min_cfg", 1.0),
min_value=1.0,
max_value=10.0,
)
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider",
"params": {
"max_scale": max_scale,
"min_scale": min_scale,
"num_frames": options["num_frames"],
**additional_guider_kwargs,
},
}
else:
raise NotImplementedError
return guider_config
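
For intuition, the schedule behind the new TrianglePredictionGuider can be sketched as a per-frame guidance scale that peaks mid-orbit; this is a simplified sketch of the idea, not necessarily the repo's exact implementation:

import numpy as np

def triangle_cfg_schedule(min_scale: float, max_scale: float, num_frames: int) -> np.ndarray:
    # Guidance ramps linearly from min_scale up to max_scale at the middle
    # frame and back down, so the views closest to the conditioning image
    # receive the least guidance.
    t = np.linspace(0.0, 1.0, num_frames)
    triangle = 1.0 - np.abs(2.0 * t - 1.0)  # 0 -> 1 -> 0
    return min_scale + (max_scale - min_scale) * triangle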
@@ -288,8 +318,8 @@ def init_sampling(
f"num cols #{key}", value=num_cols, min_value=1, max_value=10
)
- steps = st.sidebar.number_input(
- f"steps #{key}", value=options.get("num_steps", 40), min_value=1, max_value=1000
steps = st.number_input(
f"steps #{key}", value=options.get("num_steps", 50), min_value=1, max_value=1000
)
sampler = st.sidebar.selectbox(
f"Sampler #{key}",
@@ -337,13 +367,13 @@ def get_discretization(discretization, options, key=1):
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
elif discretization == "EDMDiscretization":
- sigma_min = st.number_input(
sigma_min = st.sidebar.number_input(
f"sigma_min #{key}", value=options.get("sigma_min", 0.03)
) # 0.0292
- sigma_max = st.number_input(
sigma_max = st.sidebar.number_input(
f"sigma_max #{key}", value=options.get("sigma_max", 14.61)
) # 14.6146
rho = st.number_input(f"rho #{key}", value=options.get("rho", 3.0))
rho = st.sidebar.number_input(f"rho #{key}", value=options.get("rho", 3.0))
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
@@ -542,7 +572,12 @@ def do_sample(
assert T is not None
if isinstance(
- sampler.guider, (VanillaCFG, LinearPredictionGuider)
sampler.guider,
(
VanillaCFG,
LinearPredictionGuider,
TrianglePredictionGuider,
),
):
additional_model_inputs[k] = torch.zeros(
num_samples[0] * 2, num_samples[1]
@@ -678,6 +713,12 @@ def get_batch(
batch[key] = repeat(
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
)
elif key == "polars_rad":
batch[key] = torch.tensor(value_dict["polars_rad"]).to(device).repeat(N[0])
elif key == "azimuths_rad":
batch[key] = (
torch.tensor(value_dict["azimuths_rad"]).to(device).repeat(N[0])
)
else:
batch[key] = value_dict[key]
@@ -827,8 +868,13 @@ def load_img_for_prediction(
st.image(image)
w, h = image.size
- image = np.array(image).transpose(2, 0, 1)
- image = torch.from_numpy(image).to(dtype=torch.float32) / 255.0
image = np.array(image).astype(np.float32) / 255
if image.shape[-1] == 4:
rgb, alpha = image[:, :, :3], image[:, :, 3:]
image = rgb * alpha + (1 - alpha)
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32)
image = image.unsqueeze(0)
rfs = get_resizing_factor((H, W), (h, w))
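
The new RGBA branch above composites transparent pixels onto a white background before the tensor conversion; the same operation as a standalone sketch, assuming a float array in [0, 1]:

import numpy as np

def composite_on_white(rgba: np.ndarray) -> np.ndarray:
    # rgba: H x W x 4 float array in [0, 1]; fully transparent pixels
    # (alpha = 0) come out as 1.0, i.e. pure white.
    rgb, alpha = rgba[..., :3], rgba[..., 3:]
    return rgb * alpha + (1.0 - alpha)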
@@ -860,28 +906,16 @@ def save_video_as_grid_and_mp4(
save_image(vid, fp=os.path.join(save_path, f"{base_count:06d}.png"), nrow=4)
video_path = os.path.join(save_path, f"{base_count:06d}.mp4")
- writer = cv2.VideoWriter(
- video_path,
- cv2.VideoWriter_fourcc(*"MP4V"),
- fps,
- (vid.shape[-1], vid.shape[-2]),
- )
vid = (
(rearrange(vid, "t c h w -> t h w c") * 255).cpu().numpy().astype(np.uint8)
)
- for frame in vid:
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
- writer.write(frame)
- writer.release()
imageio.mimwrite(video_path, vid, fps=fps)
video_path_h264 = video_path[:-4] + "_h264.mp4"
- os.system(f"ffmpeg -i {video_path} -c:v libx264 {video_path_h264}")
os.system(f"ffmpeg -i '{video_path}' -c:v libx264 '{video_path_h264}'")
with open(video_path_h264, "rb") as f:
video_bytes = f.read()
os.remove(video_path_h264)
st.video(video_bytes)
base_count += 1
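
The added quotes keep paths containing spaces from breaking the shell command; an alternative sketch that sidesteps shell quoting entirely, assuming ffmpeg is on PATH:

import subprocess

subprocess.run(
    ["ffmpeg", "-i", video_path, "-c:v", "libx264", video_path_h264],
    check=True,  # surface transcode failures instead of ignoring the exit code
)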

View File: scripts/demo/sv3d_helpers.py

@@ -0,0 +1,104 @@
import os
import matplotlib.pyplot as plt
import numpy as np
def generate_dynamic_cycle_xy_values(
length=21,
init_elev=0,
num_components=84,
frequency_range=(1, 5),
amplitude_range=(0.5, 10),
step_range=(0, 2),
):
# Y values generation
y_sequence = np.ones(length) * init_elev
for _ in range(num_components):
# Choose a frequency that will complete whole cycles in the sequence
frequency = np.random.randint(*frequency_range) * (2 * np.pi / length)
amplitude = np.random.uniform(*amplitude_range)
phase_shift = np.random.choice([0, np.pi]) # np.random.uniform(0, 2 * np.pi)
angles = (
np.linspace(0, frequency * length, length, endpoint=False) + phase_shift
)
y_sequence += np.sin(angles) * amplitude
# X values generation
# Generate length - 1 steps since the last step is back to start
steps = np.random.uniform(*step_range, length - 1)
total_step_sum = np.sum(steps)
# Calculate the scale factor to scale total steps to just under 360
scale_factor = (
360 - ((360 / length) * np.random.uniform(*step_range))
) / total_step_sum
# Apply the scale factor and generate the sequence of X values
x_values = np.cumsum(steps * scale_factor)
# Ensure the sequence starts at 0 and add the final step to complete the loop
x_values = np.insert(x_values, 0, 0)
return x_values, y_sequence
def smooth_data(data, window_size):
# Extend data at both ends by wrapping around to create a continuous loop
pad_size = window_size
padded_data = np.concatenate((data[-pad_size:], data, data[:pad_size]))
# Apply smoothing
kernel = np.ones(window_size) / window_size
smoothed_data = np.convolve(padded_data, kernel, mode="same")
# Extract the smoothed data corresponding to the original sequence
# Adjust the indices to account for the larger padding
start_index = pad_size
end_index = -pad_size if pad_size != 0 else None
smoothed_original_data = smoothed_data[start_index:end_index]
return smoothed_original_data
# Function to generate and process the data
def gen_dynamic_loop(length=21, elev_deg=0):
while True:
# Generate the combined X and Y values using the new function
azim_values, elev_values = generate_dynamic_cycle_xy_values(
length=84, init_elev=elev_deg
)
# Smooth the Y values directly
smoothed_elev_values = smooth_data(elev_values, 5)
max_magnitude = np.max(np.abs(smoothed_elev_values))
if max_magnitude < 90:
break
subsample = 84 // length
azim_rad = np.deg2rad(azim_values[::subsample])
elev_rad = np.deg2rad(smoothed_elev_values[::subsample])
# Make cond frame the last one
return np.roll(azim_rad, -1), np.roll(elev_rad, -1)
def plot_3D(azim, polar, save_path, dynamic=True):
os.makedirs(os.path.dirname(save_path), exist_ok=True)
elev = np.deg2rad(90) - polar
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(projection="3d")
cm = plt.get_cmap("Greys")
col_line = [cm(i) for i in np.linspace(0.3, 1, len(azim) + 1)]
cm = plt.get_cmap("cool")
col = [cm(float(i) / (len(azim))) for i in np.arange(len(azim))]
xs = np.cos(elev) * np.cos(azim)
ys = np.cos(elev) * np.sin(azim)
zs = np.sin(elev)
ax.scatter(xs[0], ys[0], zs[0], s=100, color=col[0])
xs_d, ys_d, zs_d = (xs[1:] - xs[:-1]), (ys[1:] - ys[:-1]), (zs[1:] - zs[:-1])
for i in range(len(xs) - 1):
if dynamic:
ax.quiver(
xs[i], ys[i], zs[i], xs_d[i], ys_d[i], zs_d[i], lw=2, color=col_line[i]
)
else:
ax.plot(xs[i : i + 2], ys[i : i + 2], zs[i : i + 2], lw=2, c=col_line[i])
ax.scatter(xs[i + 1], ys[i + 1], zs[i + 1], s=100, color=col[i + 1])
ax.scatter(xs[:1], ys[:1], zs[:1], s=120, facecolors="none", edgecolors="k")
ax.scatter(xs[-1:], ys[-1:], zs[-1:], s=120, facecolors="none", edgecolors="k")
ax.view_init(elev=30, azim=-20, roll=0)
plt.savefig(save_path, bbox_inches="tight")
plt.clf()
plt.close()
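
A short usage sketch tying these helpers together, mirroring how video_sampling.py calls them below (the save path here is illustrative):

import numpy as np

# Smooth dynamic orbit over 21 frames starting at 10 degrees elevation;
# gen_dynamic_loop returns azimuths/elevations in radians with the
# conditioning frame rolled to the last position.
azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=10)
polars_rad = np.deg2rad(90) - elev_rad
plot_3D(azim=azim_rad, polar=polars_rad, save_path="outputs/demo/plot_3D.png", dynamic=True)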

View File: scripts/demo/video_sampling.py

@@ -1,8 +1,10 @@
import os
import sys
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
from pytorch_lightning import seed_everything
from scripts.demo.streamlit_helpers import *
from scripts.demo.sv3d_helpers import *
SAVE_PATH = "outputs/demo/vid/"
@@ -87,11 +89,51 @@ VERSION2SPECS = {
"decoding_t": 14,
},
},
"sv3d_u": {
"T": 21,
"H": 576,
"W": 576,
"C": 4,
"f": 8,
"config": "configs/inference/sv3d_u.yaml",
"ckpt": "checkpoints/sv3d_u.safetensors",
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 3,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 50,
"decoding_t": 14,
},
},
"sv3d_p": {
"T": 21,
"H": 576,
"W": 576,
"C": 4,
"f": 8,
"config": "configs/inference/sv3d_p.yaml",
"ckpt": "checkpoints/sv3d_p.safetensors",
"options": {
"discretization": 1,
"cfg": 2.5,
"sigma_min": 0.002,
"sigma_max": 700.0,
"rho": 7.0,
"guider": 3,
"force_uc_zero_embeddings": ["cond_frames", "cond_frames_without_noise"],
"num_steps": 50,
"decoding_t": 14,
},
},
}
if __name__ == "__main__":
st.title("Stable Video Diffusion")
st.title("Stable Video Diffusion / SV3D")
version = st.selectbox(
"Model Version",
[k for k in VERSION2SPECS.keys()],
@@ -131,17 +173,42 @@ if __name__ == "__main__":
{},
)
if "fps" not in ukeys:
value_dict["fps"] = 10
value_dict["image_only_indicator"] = 0
if mode == "img2vid":
img = load_img_for_prediction(W, H)
- cond_aug = st.number_input(
- "Conditioning augmentation:", value=0.02, min_value=0.0
- )
if "sv3d" in version:
cond_aug = 1e-5
else:
cond_aug = st.number_input(
"Conditioning augmentation:", value=0.02, min_value=0.0
)
value_dict["cond_frames_without_noise"] = img
value_dict["cond_frames"] = img + cond_aug * torch.randn_like(img)
value_dict["cond_aug"] = cond_aug
if "sv3d_p" in version:
elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
trajectory = st.selectbox(
"Trajectory",
["same elevation", "dynamic"],
0,
)
if trajectory == "same elevation":
value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
elif trajectory == "dynamic":
azim_rad, elev_rad = gen_dynamic_loop(length=21, elev_deg=elev_deg)
value_dict["polars_rad"] = np.deg2rad(90) - elev_rad
value_dict["azimuths_rad"] = azim_rad
elif "sv3d_u" in version:
elev_deg = st.number_input("elev_deg", value=5, min_value=-90, max_value=90)
value_dict["polars_rad"] = np.array([np.deg2rad(90 - elev_deg)] * T)
value_dict["azimuths_rad"] = np.linspace(0, 2 * np.pi, T + 1)[1:]
seed = st.sidebar.number_input(
"seed", value=23, min_value=0, max_value=int(1e9)
)
@@ -151,6 +218,19 @@ if __name__ == "__main__":
os.path.join(SAVE_PATH, version), init_value=True
)
if "sv3d" in version:
plot_save_path = os.path.join(save_path, "plot_3D.png")
plot_3D(
azim=value_dict["azimuths_rad"],
polar=value_dict["polars_rad"],
save_path=plot_save_path,
dynamic=("sv3d_p" in version),
)
st.image(
plot_save_path,
f"3D camera trajectory",
)
options["num_frames"] = T
sampler, num_rows, num_cols = init_sampling(options=options)
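
For the "same elevation" trajectory above, the camera reduces to a constant polar angle and T evenly spaced azimuths; a worked sketch for T = 21:

import numpy as np

T, elev_deg = 21, 5
polars_rad = np.array([np.deg2rad(90 - elev_deg)] * T)  # constant polar angle of 85 degrees
azimuths_rad = np.linspace(0, 2 * np.pi, T + 1)[1:]  # 2*pi/21 rad per step, ending at 2*pi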

View File: scripts/sampling/configs/sv3d_p.yaml

@@ -0,0 +1,132 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/sv3d_p_image_decoder.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 1280
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- input_key: cond_frames_without_noise
is_trainable: False
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
- input_key: polars_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512
- input_key: azimuths_rad
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 512
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: torch.nn.Identity
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
params:
max_scale: 2.5
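
This config is consumed through OmegaConf, following the pattern load_model uses in simple_video_sample.py further below; a minimal loading sketch, assuming the checkpoint referenced by ckpt_path is present:

from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("scripts/sampling/configs/sv3d_p.yaml")
config.model.params.sampler_config.params.num_steps = 50  # load_model patches this in
model = instantiate_from_config(config.model)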

View File: scripts/sampling/configs/sv3d_u.yaml

@@ -0,0 +1,120 @@
model:
target: sgm.models.diffusion.DiffusionEngine
params:
scale_factor: 0.18215
disable_first_stage_autocast: True
ckpt_path: checkpoints/sv3d_u_image_decoder.safetensors
denoiser_config:
target: sgm.modules.diffusionmodules.denoiser.Denoiser
params:
scaling_config:
target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise
network_config:
target: sgm.modules.diffusionmodules.video_model.VideoUNet
params:
adm_in_channels: 256
num_classes: sequential
use_checkpoint: True
in_channels: 8
out_channels: 4
model_channels: 320
attention_resolutions: [4, 2, 1]
num_res_blocks: 2
channel_mult: [1, 2, 4, 4]
num_head_channels: 64
use_linear_in_transformer: True
transformer_depth: 1
context_dim: 1024
spatial_transformer_attn_type: softmax-xformers
extra_ff_mix_layer: True
use_spatial_context: True
merge_strategy: learned_with_images
video_kernel_size: [3, 1, 1]
conditioner_config:
target: sgm.modules.GeneralConditioner
params:
emb_models:
- is_trainable: False
input_key: cond_frames_without_noise
target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
params:
n_cond_frames: 1
n_copies: 1
open_clip_embedding_config:
target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
params:
freeze: True
- input_key: cond_frames
is_trainable: False
target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
params:
disable_encoder_autocast: True
n_cond_frames: 1
n_copies: 1
is_ae: True
encoder_config:
target: sgm.models.autoencoder.AutoencoderKLModeOnly
params:
embed_dim: 4
monitor: val/rec_loss
ddconfig:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [1, 2, 4, 4]
num_res_blocks: 2
attn_resolutions: []
dropout: 0.0
lossconfig:
target: torch.nn.Identity
- input_key: cond_aug
is_trainable: False
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
params:
outdim: 256
first_stage_config:
target: sgm.models.autoencoder.AutoencodingEngine
params:
loss_config:
target: torch.nn.Identity
regularizer_config:
target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
encoder_config:
target: torch.nn.Identity
decoder_config:
target: sgm.modules.diffusionmodules.model.Decoder
params:
attn_type: vanilla-xformers
double_z: True
z_channels: 4
resolution: 256
in_channels: 3
out_ch: 3
ch: 128
ch_mult: [ 1, 2, 4, 4 ]
num_res_blocks: 2
attn_resolutions: [ ]
dropout: 0.0
sampler_config:
target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
params:
discretization_config:
target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
params:
sigma_max: 700.0
guider_config:
target: sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider
params:
max_scale: 2.5
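
Note how this config differs from sv3d_p.yaml above: sv3d_u conditions only on cond_aug (outdim 256), so adm_in_channels is 256, while sv3d_p additionally embeds polars_rad and azimuths_rad (outdim 512 each), giving adm_in_channels = 256 + 512 + 512 = 1280.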

View File: scripts/sampling/simple_video_sample.py

@@ -1,27 +1,29 @@
import math
import os
import sys
from glob import glob
from pathlib import Path
- from typing import Optional
from typing import List, Optional
sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
import cv2
import imageio
import numpy as np
import torch
from einops import rearrange, repeat
from fire import Fire
from omegaconf import OmegaConf
from PIL import Image
- from torchvision.transforms import ToTensor
- from scripts.util.detection.nsfw_and_watermark_dectection import \
- DeepFloydDataFiltering
from rembg import remove
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.helpers import embed_watermark
from sgm.util import default, instantiate_from_config
from torchvision.transforms import ToTensor
def sample(
input_path: str = "assets/test_image.png", # Can either be image file or folder with image files
- num_frames: Optional[int] = None,
num_frames: Optional[int] = None, # 21 for SV3D
num_steps: Optional[int] = None,
version: str = "svd",
fps_id: int = 6,
@@ -31,6 +33,10 @@ def sample(
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
output_folder: Optional[str] = None,
elevations_deg: Optional[float | List[float]] = 10.0, # For SV3D
azimuths_deg: Optional[float | List[float]] = None, # For SV3D
image_frame_ratio: Optional[float] = None,
verbose: Optional[bool] = False,
):
"""
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
@@ -61,6 +67,24 @@ def sample(
output_folder, "outputs/simple_video_sample/svd_xt_image_decoder/"
)
model_config = "scripts/sampling/configs/svd_xt_image_decoder.yaml"
elif version == "sv3d_u":
num_frames = 21
num_steps = default(num_steps, 50)
output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_u/")
model_config = "scripts/sampling/configs/sv3d_u.yaml"
cond_aug = 1e-5
elif version == "sv3d_p":
num_frames = 21
num_steps = default(num_steps, 50)
output_folder = default(output_folder, "outputs/simple_video_sample/sv3d_p/")
model_config = "scripts/sampling/configs/sv3d_p.yaml"
cond_aug = 1e-5
if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
elevations_deg = [elevations_deg] * num_frames
polars_rad = [np.deg2rad(90 - e) for e in elevations_deg]
if azimuths_deg is None:
azimuths_deg = np.linspace(0, 360, num_frames + 1)[1:] % 360
azimuths_rad = [np.deg2rad(a) for a in azimuths_deg]
else:
raise ValueError(f"Version {version} does not exist.")
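
Evaluating the default azimuth expression above for num_frames = 21 shows the orbit advancing 360/21 ≈ 17.14 degrees per frame, with the final 360-degree view (the conditioning image) wrapped to 0:

import numpy as np

azimuths_deg = np.linspace(0, 360, 22)[1:] % 360
# [17.14, 34.29, ..., 342.86, 0.0]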
@@ -69,6 +93,7 @@ def sample(
device,
num_frames,
num_steps,
verbose,
)
torch.manual_seed(seed)
@@ -93,20 +118,56 @@ def sample(
raise ValueError
for input_img_path in all_img_paths:
- with Image.open(input_img_path) as image:
if "sv3d" in version:
image = Image.open(input_img_path)
if image.mode == "RGBA":
- image = image.convert("RGB")
- w, h = image.size
pass
else:
# remove bg
image.thumbnail([768, 768], Image.Resampling.LANCZOS)
image = remove(image.convert("RGBA"), alpha_matting=True)
- if h % 64 != 0 or w % 64 != 0:
- width, height = map(lambda x: x - x % 64, (w, h))
- image = image.resize((width, height))
- print(
- f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
- )
# resize object in frame
image_arr = np.array(image)
in_w, in_h = image_arr.shape[:2]
ret, mask = cv2.threshold(
np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
)
x, y, w, h = cv2.boundingRect(mask)
max_size = max(w, h)
side_len = (
int(max_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
center = side_len // 2
padded_image[
center - h // 2 : center - h // 2 + h,
center - w // 2 : center - w // 2 + w,
] = image_arr[y : y + h, x : x + w]
# resize frame to 576x576
rgba = Image.fromarray(padded_image).resize((576, 576), Image.LANCZOS)
# white bg
rgba_arr = np.array(rgba) / 255.0
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
input_image = Image.fromarray((rgb * 255).astype(np.uint8))
- image = ToTensor()(image)
- image = image * 2.0 - 1.0
else:
with Image.open(input_img_path) as image:
if image.mode == "RGBA":
input_image = image.convert("RGB")
w, h = image.size
if h % 64 != 0 or w % 64 != 0:
width, height = map(lambda x: x - x % 64, (w, h))
input_image = input_image.resize((width, height))
print(
f"WARNING: Your image is of size {h}x{w} which is not divisible by 64. We are resizing to {height}x{width}!"
)
image = ToTensor()(input_image)
image = image * 2.0 - 1.0
image = image.unsqueeze(0).to(device)
H, W = image.shape[2:]
@@ -114,10 +175,14 @@ def sample(
F = 8
C = 4
shape = (num_frames, C, H // F, W // F)
- if (H, W) != (576, 1024):
if (H, W) != (576, 1024) and "sv3d" not in version:
print(
"WARNING: The conditioning frame you provided is not 576x1024. This leads to suboptimal performance as model was only trained on 576x1024. Consider increasing `cond_aug`."
)
if (H, W) != (576, 576) and "sv3d" in version:
print(
"WARNING: The conditioning frame you provided is not 576x576. This leads to suboptimal performance as model was only trained on 576x576."
)
if motion_bucket_id > 255:
print(
"WARNING: High motion bucket! This may lead to suboptimal performance."
@@ -130,12 +195,14 @@ def sample(
print("WARNING: Large fps value! This may lead to suboptimal performance.")
value_dict = {}
value_dict["cond_frames_without_noise"] = image
value_dict["motion_bucket_id"] = motion_bucket_id
value_dict["fps_id"] = fps_id
value_dict["cond_aug"] = cond_aug
value_dict["cond_frames_without_noise"] = image
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
value_dict["cond_aug"] = cond_aug
if "sv3d_p" in version:
value_dict["polars_rad"] = polars_rad
value_dict["azimuths_rad"] = azimuths_rad
with torch.no_grad():
with torch.autocast(device):
@@ -177,16 +244,15 @@ def sample(
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
model.en_and_decode_n_samples_a_time = decoding_t
samples_x = model.decode_first_stage(samples_z)
if "sv3d" in version:
samples_x[-1:] = value_dict["cond_frames_without_noise"]
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
os.makedirs(output_folder, exist_ok=True)
base_count = len(glob(os.path.join(output_folder, "*.mp4")))
- video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
- writer = cv2.VideoWriter(
- video_path,
- cv2.VideoWriter_fourcc(*"MP4V"),
- fps_id + 1,
- (samples.shape[-1], samples.shape[-2]),
imageio.imwrite(
os.path.join(output_folder, f"{base_count:06d}.jpg"), input_image
)
samples = embed_watermark(samples)
@@ -197,10 +263,8 @@ def sample(
.numpy()
.astype(np.uint8)
)
- for frame in vid:
- frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
- writer.write(frame)
- writer.release()
video_path = os.path.join(output_folder, f"{base_count:06d}.mp4")
imageio.mimwrite(video_path, vid)
def get_unique_embedder_keys_from_conditioner(conditioner):
@@ -230,12 +294,10 @@ def get_batch(keys, value_dict, N, T, device):
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames":
batch[key] = repeat(value_dict["cond_frames"], "1 ... -> b ...", b=N[0])
elif key == "cond_frames_without_noise":
batch[key] = repeat(
value_dict["cond_frames_without_noise"], "1 ... -> b ...", b=N[0]
)
elif key == "cond_frames" or key == "cond_frames_without_noise":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
elif key == "polars_rad" or key == "azimuths_rad":
batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0])
else:
batch[key] = value_dict[key]
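
torch.Tensor.repeat tiles a 1-D tensor end to end, so a length-T trajectory becomes one flat batch of N[0] * T angles; a tiny sketch with hypothetical values:

import torch

polars = torch.tensor([1.48, 1.48, 1.48])  # T = 3 polar angles, for illustration
batch = polars.repeat(2)  # N[0] = 2 -> tensor of shape (6,)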
@@ -253,6 +315,7 @@ def load_model(
device: str,
num_frames: int,
num_steps: int,
verbose: bool = False,
):
config = OmegaConf.load(config)
if device == "cuda":
@@ -260,6 +323,7 @@ def load_model(
0
].params.open_clip_embedding_config.params.init_device = device
config.model.params.sampler_config.params.verbose = verbose
config.model.params.sampler_config.params.num_steps = num_steps
config.model.params.sampler_config.params.guider_config.params.num_frames = (
num_frames
)