update sv4d sampling script and readme

This commit is contained in:
Chun-Han Yao
2024-07-31 18:42:28 +00:00
parent 863665548f
commit 1cd0cbaff4
6 changed files with 81 additions and 91 deletions

125
scripts/demo/sv4d_helpers.py Normal file → Executable file
View File

@@ -36,6 +36,20 @@ from sgm.modules.diffusionmodules.sampling import (
from sgm.util import default, instantiate_from_config
def load_module_gpu(model):
model.cuda()
def unload_module_gpu(model):
model.cpu()
torch.cuda.empty_cache()
def initial_model_load(model):
model.model.half()
return model
def get_resizing_factor(
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
@@ -60,75 +74,11 @@ def get_resizing_factor(
return factor
def load_img_for_prediction_no_st(
image_path: str,
mask_path: str,
W: int,
H: int,
crop_h: int,
crop_w: int,
device="cuda",
) -> torch.Tensor:
image = Image.open(image_path)
if image is None:
return None
image = np.array(image).astype(np.float32) / 255
h, w = image.shape[:2]
rotated = 0
mask = None
if mask_path is not None:
mask = Image.open(mask_path)
mask = np.array(mask).astype(np.float32) / 255
mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype(
np.float32
)
elif image.shape[-1] == 4:
mask = image[:, :, 3:]
if mask is not None:
image = image[:, :, :3] * mask + (1 - mask)
# if "DAVIS" in image_path:
# y, x, _ = np.where(mask > 0)
# x_mean, y_mean = np.mean(x), np.mean(y)
# else:
# x_mean, y_mean = w//2, h//2
# h_new = int(max(crop_h, crop_w) * 1.33)
# x_min = max(int(x_mean - h_new//2), 0)
# y_min = max(int(y_mean - h_new//2), 0)
# image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
# h_crop, w_crop = image_cropped.shape[:2]
# h_new = max(h_crop, w_crop)
# top = max((h_new - h_crop) // 2, 0)
# left = max((h_new - w_crop) // 2, 0)
# image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
# image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
# image = image_padded
# h, w = image.shape[:2]
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32)
image = image.unsqueeze(0)
rfs = get_resizing_factor((H, W), (h, w))
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
top = (resize_size[0] - H) // 2
left = (resize_size[1] - W) // 2
image = torch.nn.functional.interpolate(
image, resize_size, mode="area", antialias=False
)
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
return image.to(device) * 2.0 - 1.0, rotated
def read_gif(input_path, n_frames):
frames = []
video = Image.open(input_path)
if video.n_frames < n_frames:
return frames
for img in ImageSequence.Iterator(video):
frames.append(img.convert("RGB"))
frames.append(img.convert("RGBA"))
if len(frames) == n_frames:
break
return frames
@@ -206,16 +156,17 @@ def read_video(
print(f"Loading {len(all_img_paths)} video frames...")
images = [Image.open(img_path) for img_path in all_img_paths]
if len(images) < n_frames:
images = (images + images[::-1])[:n_frames]
if len(images) != n_frames:
raise ValueError("Input video contains fewer than {n_frames} frames.")
raise ValueError(f"Input video contains fewer than {n_frames} frames.")
# Remove background and crop video frames
images_v0 = []
for image in images:
for t, image in enumerate(images):
if remove_bg:
if image.mode == "RGBA":
pass
else:
if image.mode != "RGBA":
image.thumbnail([W, H], Image.Resampling.LANCZOS)
image = remove(image.convert("RGBA"), alpha_matting=True)
image_arr = np.array(image)
@@ -225,11 +176,12 @@ def read_video(
)
x, y, w, h = cv2.boundingRect(mask)
max_size = max(w, h)
side_len = (
int(max_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
if t == 0:
side_len = (
int(max_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
center = side_len // 2
padded_image[
@@ -239,7 +191,9 @@ def read_video(
rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
rgba_arr = np.array(rgba) / 255.0
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
images = Image.fromarray((rgb * 255).astype(np.uint8))
image = Image.fromarray((rgb * 255).astype(np.uint8))
else:
image = image.convert("RGB").resize((W, H), Image.LANCZOS)
image = ToTensor()(image).unsqueeze(0).to(device)
images_v0.append(image * 2.0 - 1.0)
return images_v0
@@ -341,11 +295,13 @@ def sample_sv3d(
def decode_latents(model, samples_z, timesteps):
load_module_gpu(model.first_stage_model)
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
else:
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)
return samples
@@ -751,6 +707,7 @@ def do_sample(
else:
num_samples = [num_samples]
load_module_gpu(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
@@ -758,13 +715,13 @@ def do_sample(
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_module_gpu(model.conditioner)
for k in c:
if not k == "crossattn":
@@ -805,8 +762,13 @@ def do_sample(
model.model, input, sigma, c, **additional_model_inputs
)
load_module_gpu(model.model)
load_module_gpu(model.denoiser)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)
load_module_gpu(model.first_stage_model)
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(
samples_z, timesteps=default(decoding_t, T)
@@ -814,6 +776,7 @@ def do_sample(
else:
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)
if filter is not None:
samples = filter(samples)
@@ -850,6 +813,7 @@ def do_sample_per_step(
else:
num_samples = [num_samples]
load_module_gpu(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
@@ -857,13 +821,13 @@ def do_sample_per_step(
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_module_gpu(model.conditioner)
for k in c:
if not k == "crossattn":
@@ -917,6 +881,9 @@ def do_sample_per_step(
if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
else 0.0
)
load_module_gpu(model.model)
load_module_gpu(model.denoiser)
samples_z = sampler.sampler_step(
s_in * sigmas[step],
s_in * sigmas[step + 1],
@@ -926,6 +893,8 @@ def do_sample_per_step(
uc,
gamma,
)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)
return samples_z