mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-24 00:34:20 +01:00
update sv4d sampling script and readme
This commit is contained in:
125
scripts/demo/sv4d_helpers.py
Normal file → Executable file
125
scripts/demo/sv4d_helpers.py
Normal file → Executable file
@@ -36,6 +36,20 @@ from sgm.modules.diffusionmodules.sampling import (
|
||||
from sgm.util import default, instantiate_from_config
|
||||
|
||||
|
||||
def load_module_gpu(model):
|
||||
model.cuda()
|
||||
|
||||
|
||||
def unload_module_gpu(model):
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def initial_model_load(model):
|
||||
model.model.half()
|
||||
return model
|
||||
|
||||
|
||||
def get_resizing_factor(
|
||||
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
|
||||
) -> float:
|
||||
@@ -60,75 +74,11 @@ def get_resizing_factor(
|
||||
return factor
|
||||
|
||||
|
||||
def load_img_for_prediction_no_st(
|
||||
image_path: str,
|
||||
mask_path: str,
|
||||
W: int,
|
||||
H: int,
|
||||
crop_h: int,
|
||||
crop_w: int,
|
||||
device="cuda",
|
||||
) -> torch.Tensor:
|
||||
image = Image.open(image_path)
|
||||
if image is None:
|
||||
return None
|
||||
image = np.array(image).astype(np.float32) / 255
|
||||
h, w = image.shape[:2]
|
||||
rotated = 0
|
||||
|
||||
mask = None
|
||||
if mask_path is not None:
|
||||
mask = Image.open(mask_path)
|
||||
mask = np.array(mask).astype(np.float32) / 255
|
||||
mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype(
|
||||
np.float32
|
||||
)
|
||||
elif image.shape[-1] == 4:
|
||||
mask = image[:, :, 3:]
|
||||
|
||||
if mask is not None:
|
||||
image = image[:, :, :3] * mask + (1 - mask)
|
||||
# if "DAVIS" in image_path:
|
||||
# y, x, _ = np.where(mask > 0)
|
||||
# x_mean, y_mean = np.mean(x), np.mean(y)
|
||||
# else:
|
||||
# x_mean, y_mean = w//2, h//2
|
||||
# h_new = int(max(crop_h, crop_w) * 1.33)
|
||||
# x_min = max(int(x_mean - h_new//2), 0)
|
||||
# y_min = max(int(y_mean - h_new//2), 0)
|
||||
# image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
|
||||
# h_crop, w_crop = image_cropped.shape[:2]
|
||||
# h_new = max(h_crop, w_crop)
|
||||
# top = max((h_new - h_crop) // 2, 0)
|
||||
# left = max((h_new - w_crop) // 2, 0)
|
||||
# image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
|
||||
# image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
|
||||
# image = image_padded
|
||||
# h, w = image.shape[:2]
|
||||
|
||||
image = image.transpose(2, 0, 1)
|
||||
image = torch.from_numpy(image).to(dtype=torch.float32)
|
||||
image = image.unsqueeze(0)
|
||||
|
||||
rfs = get_resizing_factor((H, W), (h, w))
|
||||
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
|
||||
top = (resize_size[0] - H) // 2
|
||||
left = (resize_size[1] - W) // 2
|
||||
|
||||
image = torch.nn.functional.interpolate(
|
||||
image, resize_size, mode="area", antialias=False
|
||||
)
|
||||
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
|
||||
return image.to(device) * 2.0 - 1.0, rotated
|
||||
|
||||
|
||||
def read_gif(input_path, n_frames):
|
||||
frames = []
|
||||
video = Image.open(input_path)
|
||||
if video.n_frames < n_frames:
|
||||
return frames
|
||||
for img in ImageSequence.Iterator(video):
|
||||
frames.append(img.convert("RGB"))
|
||||
frames.append(img.convert("RGBA"))
|
||||
if len(frames) == n_frames:
|
||||
break
|
||||
return frames
|
||||
@@ -206,16 +156,17 @@ def read_video(
|
||||
print(f"Loading {len(all_img_paths)} video frames...")
|
||||
images = [Image.open(img_path) for img_path in all_img_paths]
|
||||
|
||||
if len(images) < n_frames:
|
||||
images = (images + images[::-1])[:n_frames]
|
||||
|
||||
if len(images) != n_frames:
|
||||
raise ValueError("Input video contains fewer than {n_frames} frames.")
|
||||
raise ValueError(f"Input video contains fewer than {n_frames} frames.")
|
||||
|
||||
# Remove background and crop video frames
|
||||
images_v0 = []
|
||||
for image in images:
|
||||
for t, image in enumerate(images):
|
||||
if remove_bg:
|
||||
if image.mode == "RGBA":
|
||||
pass
|
||||
else:
|
||||
if image.mode != "RGBA":
|
||||
image.thumbnail([W, H], Image.Resampling.LANCZOS)
|
||||
image = remove(image.convert("RGBA"), alpha_matting=True)
|
||||
image_arr = np.array(image)
|
||||
@@ -225,11 +176,12 @@ def read_video(
|
||||
)
|
||||
x, y, w, h = cv2.boundingRect(mask)
|
||||
max_size = max(w, h)
|
||||
side_len = (
|
||||
int(max_size / image_frame_ratio)
|
||||
if image_frame_ratio is not None
|
||||
else in_w
|
||||
)
|
||||
if t == 0:
|
||||
side_len = (
|
||||
int(max_size / image_frame_ratio)
|
||||
if image_frame_ratio is not None
|
||||
else in_w
|
||||
)
|
||||
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
|
||||
center = side_len // 2
|
||||
padded_image[
|
||||
@@ -239,7 +191,9 @@ def read_video(
|
||||
rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
|
||||
rgba_arr = np.array(rgba) / 255.0
|
||||
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
|
||||
images = Image.fromarray((rgb * 255).astype(np.uint8))
|
||||
image = Image.fromarray((rgb * 255).astype(np.uint8))
|
||||
else:
|
||||
image = image.convert("RGB").resize((W, H), Image.LANCZOS)
|
||||
image = ToTensor()(image).unsqueeze(0).to(device)
|
||||
images_v0.append(image * 2.0 - 1.0)
|
||||
return images_v0
|
||||
@@ -341,11 +295,13 @@ def sample_sv3d(
|
||||
|
||||
|
||||
def decode_latents(model, samples_z, timesteps):
|
||||
load_module_gpu(model.first_stage_model)
|
||||
if isinstance(model.first_stage_model.decoder, VideoDecoder):
|
||||
samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
|
||||
else:
|
||||
samples_x = model.decode_first_stage(samples_z)
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
unload_module_gpu(model.first_stage_model)
|
||||
return samples
|
||||
|
||||
|
||||
@@ -751,6 +707,7 @@ def do_sample(
|
||||
else:
|
||||
num_samples = [num_samples]
|
||||
|
||||
load_module_gpu(model.conditioner)
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
@@ -758,13 +715,13 @@ def do_sample(
|
||||
T=T,
|
||||
additional_batch_uc_fields=additional_batch_uc_fields,
|
||||
)
|
||||
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
force_cond_zero_embeddings=force_cond_zero_embeddings,
|
||||
)
|
||||
unload_module_gpu(model.conditioner)
|
||||
|
||||
for k in c:
|
||||
if not k == "crossattn":
|
||||
@@ -805,8 +762,13 @@ def do_sample(
|
||||
model.model, input, sigma, c, **additional_model_inputs
|
||||
)
|
||||
|
||||
load_module_gpu(model.model)
|
||||
load_module_gpu(model.denoiser)
|
||||
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
|
||||
unload_module_gpu(model.model)
|
||||
unload_module_gpu(model.denoiser)
|
||||
|
||||
load_module_gpu(model.first_stage_model)
|
||||
if isinstance(model.first_stage_model.decoder, VideoDecoder):
|
||||
samples_x = model.decode_first_stage(
|
||||
samples_z, timesteps=default(decoding_t, T)
|
||||
@@ -814,6 +776,7 @@ def do_sample(
|
||||
else:
|
||||
samples_x = model.decode_first_stage(samples_z)
|
||||
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
unload_module_gpu(model.first_stage_model)
|
||||
|
||||
if filter is not None:
|
||||
samples = filter(samples)
|
||||
@@ -850,6 +813,7 @@ def do_sample_per_step(
|
||||
else:
|
||||
num_samples = [num_samples]
|
||||
|
||||
load_module_gpu(model.conditioner)
|
||||
batch, batch_uc = get_batch(
|
||||
get_unique_embedder_keys_from_conditioner(model.conditioner),
|
||||
value_dict,
|
||||
@@ -857,13 +821,13 @@ def do_sample_per_step(
|
||||
T=T,
|
||||
additional_batch_uc_fields=additional_batch_uc_fields,
|
||||
)
|
||||
|
||||
c, uc = model.conditioner.get_unconditional_conditioning(
|
||||
batch,
|
||||
batch_uc=batch_uc,
|
||||
force_uc_zero_embeddings=force_uc_zero_embeddings,
|
||||
force_cond_zero_embeddings=force_cond_zero_embeddings,
|
||||
)
|
||||
unload_module_gpu(model.conditioner)
|
||||
|
||||
for k in c:
|
||||
if not k == "crossattn":
|
||||
@@ -917,6 +881,9 @@ def do_sample_per_step(
|
||||
if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
|
||||
else 0.0
|
||||
)
|
||||
|
||||
load_module_gpu(model.model)
|
||||
load_module_gpu(model.denoiser)
|
||||
samples_z = sampler.sampler_step(
|
||||
s_in * sigmas[step],
|
||||
s_in * sigmas[step + 1],
|
||||
@@ -926,6 +893,8 @@ def do_sample_per_step(
|
||||
uc,
|
||||
gamma,
|
||||
)
|
||||
unload_module_gpu(model.model)
|
||||
unload_module_gpu(model.denoiser)
|
||||
|
||||
return samples_z
|
||||
|
||||
|
||||
6
scripts/sampling/configs/sv4d.yaml
Normal file → Executable file
6
scripts/sampling/configs/sv4d.yaml
Normal file → Executable file
@@ -93,12 +93,6 @@ model:
|
||||
sigma_sampler_config:
|
||||
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
|
||||
|
||||
# - input_key: cond_aug
|
||||
# is_trainable: False
|
||||
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
# params:
|
||||
# outdim: 256
|
||||
|
||||
- input_key: polar_rad
|
||||
is_trainable: False
|
||||
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
|
||||
|
||||
9
scripts/sampling/simple_video_sample_4d.py
Normal file → Executable file
9
scripts/sampling/simple_video_sample_4d.py
Normal file → Executable file
@@ -13,6 +13,7 @@ from fire import Fire
|
||||
from scripts.demo.sv4d_helpers import (
|
||||
decode_latents,
|
||||
load_model,
|
||||
initial_model_load,
|
||||
read_video,
|
||||
run_img2vid,
|
||||
run_img2vid_per_step,
|
||||
@@ -26,6 +27,7 @@ def sample(
|
||||
output_folder: Optional[str] = "outputs/sv4d",
|
||||
num_steps: Optional[int] = 20,
|
||||
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
|
||||
img_size: int = 576, # image resolution
|
||||
fps_id: int = 6,
|
||||
motion_bucket_id: int = 127,
|
||||
cond_aug: float = 1e-5,
|
||||
@@ -47,7 +49,7 @@ def sample(
|
||||
V = 8 # number of views per sample
|
||||
F = 8 # vae factor to downsize image->latent
|
||||
C = 4
|
||||
H, W = 576, 576
|
||||
H, W = img_size, img_size
|
||||
n_frames = 21 # number of input and output video frames
|
||||
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
|
||||
n_views_sv3d = 21
|
||||
@@ -64,7 +66,7 @@ def sample(
|
||||
"f": F,
|
||||
"options": {
|
||||
"discretization": 1,
|
||||
"cfg": 2.5,
|
||||
"cfg": 3.0,
|
||||
"sigma_min": 0.002,
|
||||
"sigma_max": 700.0,
|
||||
"rho": 7.0,
|
||||
@@ -137,7 +139,7 @@ def sample(
|
||||
for t in range(n_frames):
|
||||
img_matrix[t][0] = images_v0[t]
|
||||
|
||||
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
|
||||
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
|
||||
save_video(
|
||||
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
|
||||
img_matrix[0],
|
||||
@@ -155,6 +157,7 @@ def sample(
|
||||
num_steps,
|
||||
verbose,
|
||||
)
|
||||
model = initial_model_load(model)
|
||||
|
||||
# Interleaved sampling for anchor frames
|
||||
t0, v0 = 0, 0
|
||||
|
||||
Reference in New Issue
Block a user