Mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-19 14:24:21 +01:00)
SV4D: reduce the memory consumption and speed up
@@ -16,9 +16,12 @@ from scripts.demo.sv4d_helpers import (
     initial_model_load,
     read_video,
     run_img2vid,
     run_img2vid_per_step,
+    prepare_sampling,
+    prepare_inputs,
+    do_sample_per_step,
     sample_sv3d,
     save_video,
     preprocess_video,
 )
@@ -32,11 +35,11 @@ def sample(
     motion_bucket_id: int = 127,
     cond_aug: float = 1e-5,
     seed: int = 23,
-    decoding_t: int = 14,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
+    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
     device: str = "cuda",
     elevations_deg: Optional[Union[float, List[float]]] = 10.0,
     azimuths_deg: Optional[List[float]] = None,
-    image_frame_ratio: Optional[float] = None,
+    image_frame_ratio: Optional[float] = 0.917,
     verbose: Optional[bool] = False,
     remove_bg: bool = False,
 ):
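The decoding_t default drops from 14 to 4 frames per decoder pass, which is where most of the memory saving comes from: peak VRAM during decoding grows with how many frames the VAE holds at once, at the cost of a few extra decoder calls. A minimal sketch of the chunked-decoding idea, with an illustrative helper name and signature rather than the repo's API:

import torch

def decode_in_chunks(vae_decode, latents: torch.Tensor, decoding_t: int = 4) -> torch.Tensor:
    # latents: [n_frames, C, h, w]. Decode at most `decoding_t` frames per call,
    # so peak activation memory scales with decoding_t rather than with n_frames.
    frames = []
    for i in range(0, latents.shape[0], decoding_t):
        with torch.no_grad():
            frames.append(vae_decode(latents[i : i + decoding_t]))
    return torch.cat(frames, dim=0)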
@@ -89,15 +92,16 @@ def sample(
     # Read input video frames i.e. images at view 0
     print(f"Reading {input_path}")
-    images_v0 = read_video(
+    processed_input_path = preprocess_video(
         input_path,
-        remove_bg=remove_bg,
         n_frames=n_frames,
         W=W,
         H=H,
+        remove_bg=remove_bg,
         output_folder=output_folder,
         image_frame_ratio=image_frame_ratio,
         device=device,
     )
+    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)

     # Get camera viewpoints
     if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
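Input handling is now split in two: preprocess_video does the expensive work once (resizing to W x H, optional background removal, cropping by image_frame_ratio) and writes the processed clip into output_folder, and read_video only has to load that file. A rough sketch of the preprocess-once pattern, with hypothetical helper names:

import os

def get_processed_video(input_path: str, output_folder: str, preprocess) -> str:
    # Run the costly preprocessing a single time and cache the result next to
    # the outputs; later stages (and reruns) only need a cheap read.
    processed_path = os.path.join(output_folder, "processed_input.mp4")
    if not os.path.exists(processed_path):
        preprocess(input_path, processed_path)
    return processed_path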
@@ -139,7 +143,7 @@ def sample(
     for t in range(n_frames):
         img_matrix[t][0] = images_v0[t]

-    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
+    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 12
     save_video(
         os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
         img_matrix[0],
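base_count numbers a run's outputs by counting the .mp4 files already in output_folder and dividing by the number of videos one run leaves behind, which goes from 11 to 12 here, presumably because the preprocessed input clip is now also saved to that folder. A worked example of the numbering:

# Two finished runs leave 24 .mp4 files in output_folder (12 per run), so the
# next run picks up at:
base_count = 24 // 12  # -> 2, i.e. new outputs are saved as 000002_*.mp4 etc.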
@@ -171,7 +175,7 @@ def sample(
         azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
         azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
         samples = run_img2vid(
-            version_dict, model, image, seed, polars, azims, cond_motion, cond_view
+            version_dict, model, image, seed, polars, azims, cond_motion, cond_view, decoding_t
         )
         samples = samples.view(T, V, 3, H, W)
         for i, t in enumerate(frame_indices):
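run_img2vid now receives decoding_t instead of using a value fixed inside the helper, so anchor-frame sampling respects the same reduced decoder chunk size. For reference, the sibling SVD sampling script applies this knob by capping how many samples the autoencoder processes per forward pass; the SV4D helper may wire it differently, so treat the attribute below as an illustration only:

def apply_decoding_chunk(model, decoding_t: int):
    # Cap how many frames the autoencoder encodes/decodes per forward pass;
    # smaller values lower peak VRAM at a small speed cost.
    model.en_and_decode_n_samples_a_time = decoding_t
    return model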
@@ -185,40 +189,48 @@ def sample(
     frame_indices = t0 + np.arange(T)
     print(f"Sampling dense frames {frame_indices}")
     latent_matrix = torch.randn(n_frames, n_views, C, H // F, W // F).to("cuda")
+
+    polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
+    azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
+    azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
+
+    # alternate between forward and backward conditioning
+    forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices = prepare_inputs(
+        frame_indices,
+        img_matrix,
+        v0,
+        view_indices,
+        model,
+        version_dict,
+        seed,
+        polars,
+        azims
+    )
+
     for step in tqdm(range(num_steps)):
-        frame_indices = frame_indices[
-            ::-1
-        ].copy()  # alternate between forward and backward conditioning
-        t0 = frame_indices[0]
-        image = img_matrix[t0][v0]
-        cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
-        cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
-        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
-        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
-        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
+        if step % 2 == 1:
+            c, uc, additional_model_inputs, sampler = forward_inputs
+            frame_indices = forward_frame_indices
+        else:
+            c, uc, additional_model_inputs, sampler = backward_inputs
+            frame_indices = backward_frame_indices
         noisy_latents = latent_matrix[frame_indices][:, view_indices].flatten(0, 1)
-        samples = run_img2vid_per_step(
-            version_dict,
+
+        samples = do_sample_per_step(
             model,
-            image,
-            seed,
-            polars,
-            azims,
-            cond_motion,
-            cond_view,
-            step,
+            sampler,
             noisy_latents,
+            c,
+            uc,
+            step,
+            additional_model_inputs,
         )
         samples = samples.view(T, V, C, H // F, W // F)
         for i, t in enumerate(frame_indices):
             for j, v in enumerate(view_indices):
                 latent_matrix[t, v] = samples[i, j]

-    for t in frame_indices:
-        for v in view_indices:
-            if t != 0 and v != 0:
-                img = decode_latents(model, latent_matrix[t, v][None], T)
-                img_matrix[t][v] = img * 2 - 1
+    img_matrix = decode_latents(model, latent_matrix, img_matrix, frame_indices, view_indices, T)

     # Save output videos
     for v in view_indices:
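The main restructuring is in the dense-frame loop. Previously every denoising step rebuilt the conditioning (anchor image, motion and view frames, camera angles) and ran a full run_img2vid_per_step; now prepare_inputs builds the conditionings, sampler and additional model inputs once for the forward frame order and once for the backward order, and each step just picks one of the two and advances the shared latent_matrix with do_sample_per_step. Per-frame, per-view decoding at the end is likewise replaced by a single batched decode_latents call. A schematic of the alternating per-step pattern, using stand-in callables rather than the repo's exact signatures:

def alternating_per_step_sampling(step_fn, decode_fn, latents, forward_inputs,
                                  backward_inputs, num_steps):
    # latents: [n_frames, n_views, C, h, w] noise shared across all steps.
    # forward_inputs / backward_inputs: conditionings prepared once, outside
    # the loop (stand-ins for prepare_inputs' outputs).
    # step_fn(latents, cond, step) performs ONE denoiser update and returns
    # the updated latents (stand-in for do_sample_per_step).
    for step in range(num_steps):
        cond = forward_inputs if step % 2 == 1 else backward_inputs
        latents = step_fn(latents, cond, step)
    # One batched decode at the end instead of a per-frame, per-view loop.
    return decode_fn(latents)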