SV4D: reduce memory consumption and speed up sampling

commit 854bd4f0df (parent e0596f1aca)
Author: ymxie97
Date:   2024-08-02 04:59:37 +00:00

3 changed files with 253 additions and 114 deletions


@@ -16,9 +16,12 @@ from scripts.demo.sv4d_helpers import (
     initial_model_load,
     read_video,
     run_img2vid,
-    run_img2vid_per_step,
+    prepare_sampling,
+    prepare_inputs,
+    do_sample_per_step,
     sample_sv3d,
     save_video,
+    preprocess_video,
 )
@@ -32,11 +35,11 @@ def sample(
     motion_bucket_id: int = 127,
     cond_aug: float = 1e-5,
     seed: int = 23,
-    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
     device: str = "cuda",
     elevations_deg: Optional[Union[float, List[float]]] = 10.0,
     azimuths_deg: Optional[List[float]] = None,
-    image_frame_ratio: Optional[float] = None,
+    image_frame_ratio: Optional[float] = 0.917,
     verbose: Optional[bool] = False,
     remove_bg: bool = False,
 ):
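Lowering the default decoding_t from 14 to 4 shrinks the largest transient allocation in the pipeline: the VAE decoder now only ever sees four latent frames at a time instead of fourteen. A minimal sketch of the chunked-decoding idea, with decode_fn standing in for the model's VAE decoder (illustrative, not the repo's decode_latents signature):

```python
import torch

def decode_in_chunks(decode_fn, latents: torch.Tensor, decoding_t: int = 4) -> torch.Tensor:
    """latents: (T, C, h, w); decode_fn maps an (n, C, h, w) latent batch to images."""
    frames = []
    for i in range(0, latents.shape[0], decoding_t):
        # Peak decoder memory now scales with decoding_t, not with T.
        with torch.no_grad():
            frames.append(decode_fn(latents[i : i + decoding_t]))
    return torch.cat(frames, 0)
```

The trade-off is a few extra decoder invocations per video, which is usually negligible next to the VRAM savings.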
@@ -89,15 +92,16 @@ def sample(
     # Read input video frames i.e. images at view 0
     print(f"Reading {input_path}")
-    images_v0 = read_video(
+    processed_input_path = preprocess_video(
         input_path,
-        remove_bg=remove_bg,
         n_frames=n_frames,
         W=W,
         H=H,
+        remove_bg=remove_bg,
         output_folder=output_folder,
         image_frame_ratio=image_frame_ratio,
         device=device,
     )
+    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)

     # Get camera viewpoints
     if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
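Background removal, cropping, and resizing now happen once in preprocess_video, which writes the processed clip to disk; read_video then only loads frames, so repeated reads stay cheap. The new image_frame_ratio default of 0.917 controls how much of the square frame the object occupies. A rough sketch of how such a ratio is typically applied when fitting a segmented object onto a canvas (hypothetical helper, not the actual sv4d_helpers signature):

```python
from PIL import Image

def fit_to_frame(obj: Image.Image, size: int = 576, ratio: float = 0.917) -> Image.Image:
    """obj: RGBA image whose alpha channel marks the foreground object."""
    # Scale so the object's longer side spans `ratio` of the frame.
    scale = ratio * size / max(obj.size)
    resized = obj.resize((round(obj.width * scale), round(obj.height * scale)))
    # Center on a white square canvas, using the alpha channel as the paste mask.
    canvas = Image.new("RGBA", (size, size), (255, 255, 255, 255))
    canvas.paste(resized, ((size - resized.width) // 2, (size - resized.height) // 2), resized)
    return canvas.convert("RGB")
```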
@@ -139,7 +143,7 @@
     for t in range(n_frames):
         img_matrix[t][0] = images_v0[t]
-    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
+    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12
     save_video(
         os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
         img_matrix[0],
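The divisor changing from 11 to 12 presumably reflects that preprocess_video now saves the processed input clip into output_folder as an extra .mp4, so each completed run leaves twelve videos instead of eleven and base_count still numbers runs consecutively.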
@@ -171,7 +175,7 @@
         azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
         azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
         samples = run_img2vid(
-            version_dict, model, image, seed, polars, azims, cond_motion, cond_view
+            version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
         )
         samples = samples.view(T, V, 3, H, W)
         for i, t in enumerate(frame_indices):
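Threading decoding_t through run_img2vid lets the lower default of 4 take effect inside the sparse sampling call as well, rather than only at the final decode.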
@@ -185,40 +189,48 @@
         frame_indices = t0 + np.arange(T)
         print(f"Sampling dense frames {frame_indices}")
+        latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")
+        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
+        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
+        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
+        # alternate between forward and backward conditioning
+        forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
+            frame_indices,
+            img_matrix,
+            v0,
+            view_indices,
+            model,
+            version_dict,
+            seed,
+            polars,
+            azims,
+        )
         for step in tqdm(range(num_steps)):
-            frame_indices = frame_indices[
-                ::-1
-            ].copy()  # alternate between forward and backward conditioning
-            t0 = frame_indices[0]
-            image = img_matrix[t0][v0]
-            cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
-            cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
-            polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
-            azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
-            azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
+            if step % 2 == 1:
+                c, uc, additional_model_inputs, sampler = forward_inputs
+                frame_indices = forward_frame_indices
+            else:
+                c, uc, additional_model_inputs, sampler = backward_inputs
+                frame_indices = backward_frame_indices
+            noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)
-            samples = run_img2vid_per_step(
-                version_dict,
+            samples = do_sample_per_step(
                 model,
-                image,
-                seed,
-                polars,
-                azims,
-                cond_motion,
-                cond_view,
-                step,
+                sampler,
                 noisy_latents,
+                c,
+                uc,
+                step,
+                additional_model_inputs,
             )
             samples = samples.view(T, V, C, H // F, W // F)
             for i, t in enumerate(frame_indices):
                 for j, v in enumerate(view_indices):
                     latent_matrix[t, v] = samples[i, j]
-        for t in frame_indices:
-            for v in view_indices:
-                if t != 0 and v != 0:
-                    img = decode_latents(model, latent_matrix[t, v][None], T)
-                    img_matrix[t][v] = img * 2 - 1
+        img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)

     # Save output videos
     for v in view_indices:
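The dense pass previously rebuilt and re-encoded its conditioning on every denoising step; now prepare_inputs constructs the forward and backward conditioning once, each iteration runs a single sampler step in alternating temporal direction over a shared latent matrix, and everything is decoded once at the end. A minimal sketch of that control flow, with step_fn standing in for one denoiser/sampler step (illustrative, not the repo's do_sample_per_step API):

```python
import torch

def sample_alternating(step_fn, latents: torch.Tensor, fwd_cond, bwd_cond, num_steps: int) -> torch.Tensor:
    """latents: (T, ...); fwd_cond/bwd_cond: conditioning built once per direction."""
    fwd_order = torch.arange(latents.shape[0])
    bwd_order = fwd_order.flip(0)
    for step in range(num_steps):
        # Alternate temporal direction each step, mirroring the diff's step % 2 branch.
        order, cond = (fwd_order, fwd_cond) if step % 2 == 1 else (bwd_order, bwd_cond)
        latents[order] = step_fn(latents[order], cond, step)
    return latents
```

Building the conditioning once and keeping all state in latent space until the final decode_latents call avoids both the per-step conditioner forward passes and the per-frame, per-view decoder calls of the old loop, which is where the claimed speedup and memory reduction come from.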