Mirror of https://github.com/Stability-AI/generative-models.git
Synced 2025-12-20 14:54:21 +01:00
update sv4d sampling script and readme
scripts/demo/sv4d_helpers.py (125 changed lines; Normal file → Executable file)
@@ -36,6 +36,20 @@ from sgm.modules.diffusionmodules.sampling import (
 from sgm.util import default, instantiate_from_config
 
 
+def load_module_gpu(model):
+    model.cuda()
+
+
+def unload_module_gpu(model):
+    model.cpu()
+    torch.cuda.empty_cache()
+
+
+def initial_model_load(model):
+    model.model.half()
+    return model
+
+
 def get_resizing_factor(
     desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
 ) -> float:
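The three helpers above implement a straightforward CPU-offload scheme: a sub-module is moved to the GPU only while it is in use, then moved back while torch.cuda.empty_cache() releases the cached blocks, and the UNet weights are cast to fp16 once at load time. The memory effect of the half() cast is easy to verify in isolation; a minimal sketch, with a small Linear layer standing in for model.model:

    import torch

    layer = torch.nn.Linear(1024, 1024)  # stand-in for model.model, not part of the script
    fp32 = sum(p.numel() * p.element_size() for p in layer.parameters())
    layer.half()                          # what initial_model_load does to the UNet
    fp16 = sum(p.numel() * p.element_size() for p in layer.parameters())
    print(fp32, "->", fp16)               # 4198400 -> 2099200: weight memory halves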
@@ -60,75 +74,11 @@ def get_resizing_factor(
     return factor
 
 
-def load_img_for_prediction_no_st(
-    image_path: str,
-    mask_path: str,
-    W: int,
-    H: int,
-    crop_h: int,
-    crop_w: int,
-    device="cuda",
-) -> torch.Tensor:
-    image = Image.open(image_path)
-    if image is None:
-        return None
-    image = np.array(image).astype(np.float32) / 255
-    h, w = image.shape[:2]
-    rotated = 0
-
-    mask = None
-    if mask_path is not None:
-        mask = Image.open(mask_path)
-        mask = np.array(mask).astype(np.float32) / 255
-        mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype(
-            np.float32
-        )
-    elif image.shape[-1] == 4:
-        mask = image[:, :, 3:]
-
-    if mask is not None:
-        image = image[:, :, :3] * mask + (1 - mask)
-        # if "DAVIS" in image_path:
-        #     y, x, _ = np.where(mask > 0)
-        #     x_mean, y_mean = np.mean(x), np.mean(y)
-        # else:
-        #     x_mean, y_mean = w//2, h//2
-        # h_new = int(max(crop_h, crop_w) * 1.33)
-        # x_min = max(int(x_mean - h_new//2), 0)
-        # y_min = max(int(y_mean - h_new//2), 0)
-        # image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
-        # h_crop, w_crop = image_cropped.shape[:2]
-        # h_new = max(h_crop, w_crop)
-        # top = max((h_new - h_crop) // 2, 0)
-        # left = max((h_new - w_crop) // 2, 0)
-        # image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
-        # image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
-        # image = image_padded
-        # h, w = image.shape[:2]
-
-    image = image.transpose(2, 0, 1)
-    image = torch.from_numpy(image).to(dtype=torch.float32)
-    image = image.unsqueeze(0)
-
-    rfs = get_resizing_factor((H, W), (h, w))
-    resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
-    top = (resize_size[0] - H) // 2
-    left = (resize_size[1] - W) // 2
-
-    image = torch.nn.functional.interpolate(
-        image, resize_size, mode="area", antialias=False
-    )
-    image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
-    return image.to(device) * 2.0 - 1.0, rotated
-
-
 def read_gif(input_path, n_frames):
     frames = []
     video = Image.open(input_path)
     if video.n_frames < n_frames:
         return frames
     for img in ImageSequence.Iterator(video):
-        frames.append(img.convert("RGB"))
+        frames.append(img.convert("RGBA"))
         if len(frames) == n_frames:
             break
     return frames
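read_gif now converts frames to RGBA instead of RGB, keeping any GIF transparency available to the masking logic downstream. A minimal sketch of the same iteration, assuming a local input.gif (hypothetical path):

    from PIL import Image, ImageSequence

    video = Image.open("input.gif")  # hypothetical sample file
    frames = [frame.convert("RGBA") for frame in ImageSequence.Iterator(video)]
    print(len(frames), frames[0].mode)  # e.g. 21 RGBA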
@@ -206,16 +156,17 @@ def read_video(
     print(f"Loading {len(all_img_paths)} video frames...")
     images = [Image.open(img_path) for img_path in all_img_paths]
 
+    if len(images) < n_frames:
+        images = (images + images[::-1])[:n_frames]
+
     if len(images) != n_frames:
-        raise ValueError("Input video contains fewer than {n_frames} frames.")
+        raise ValueError(f"Input video contains fewer than {n_frames} frames.")
 
     # Remove background and crop video frames
     images_v0 = []
-    for image in images:
+    for t, image in enumerate(images):
         if remove_bg:
-            if image.mode == "RGBA":
-                pass
-            else:
+            if image.mode != "RGBA":
                 image.thumbnail([W, H], Image.Resampling.LANCZOS)
                 image = remove(image.convert("RGBA"), alpha_matting=True)
             image_arr = np.array(image)
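The added padding extends a too-short clip by appending its own frames in reverse (a ping-pong loop) before the length check. Illustrated with integers standing in for PIL images:

    frames = [0, 1, 2]  # stand-ins for PIL images
    n_frames = 8
    padded = (frames + frames[::-1])[:n_frames]
    print(padded)  # [0, 1, 2, 2, 1, 0]: still shorter than 8, so the ValueError above fires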
@@ -225,11 +176,12 @@ def read_video(
             )
             x, y, w, h = cv2.boundingRect(mask)
             max_size = max(w, h)
-            side_len = (
-                int(max_size / image_frame_ratio)
-                if image_frame_ratio is not None
-                else in_w
-            )
+            if t == 0:
+                side_len = (
+                    int(max_size / image_frame_ratio)
+                    if image_frame_ratio is not None
+                    else in_w
+                )
             padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
             center = side_len // 2
             padded_image[
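Computing side_len only when t == 0 pins the crop size to the first frame, so the subject no longer rescales from frame to frame as its bounding box changes. For reference, cv2.boundingRect returns the tight box around the nonzero mask pixels; a toy example:

    import cv2
    import numpy as np

    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[20:40, 10:50] = 255  # toy foreground blob
    x, y, w, h = cv2.boundingRect(mask)
    print(x, y, w, h, max(w, h))  # 10 20 40 20 40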
@@ -239,7 +191,9 @@ def read_video(
             rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
             rgba_arr = np.array(rgba) / 255.0
             rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
-            images = Image.fromarray((rgb * 255).astype(np.uint8))
+            image = Image.fromarray((rgb * 255).astype(np.uint8))
+        else:
+            image = image.convert("RGB").resize((W, H), Image.LANCZOS)
         image = ToTensor()(image).unsqueeze(0).to(device)
         images_v0.append(image * 2.0 - 1.0)
     return images_v0
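The rgb line composites the RGBA cutout over a white background, rgb = fg * alpha + (1 - alpha), with channels in [0, 1]. A one-pixel check:

    import numpy as np

    rgba_arr = np.array([[[1.0, 0.0, 0.0, 0.5]]])  # one half-transparent red pixel
    rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
    print(rgb)  # [[[1.  0.5 0.5]]]: red faded toward white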
@@ -341,11 +295,13 @@ def sample_sv3d(
 
 
 def decode_latents(model, samples_z, timesteps):
+    load_module_gpu(model.first_stage_model)
     if isinstance(model.first_stage_model.decoder, VideoDecoder):
         samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
     else:
         samples_x = model.decode_first_stage(samples_z)
     samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+    unload_module_gpu(model.first_stage_model)
     return samples
 
 
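decode_latents wraps the VAE decode in the same load/unload pattern and maps decoder output from [-1, 1] to [0, 1], clamping any overshoot. The normalization step on its own:

    import torch

    samples_x = torch.tensor([-1.5, -1.0, 0.0, 1.0, 1.5])  # decoder output, nominally in [-1, 1]
    samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
    print(samples)  # tensor([0.0000, 0.0000, 0.5000, 1.0000, 1.0000])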
@@ -751,6 +707,7 @@ def do_sample(
                 else:
                     num_samples = [num_samples]
 
+                load_module_gpu(model.conditioner)
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
                     value_dict,
@@ -758,13 +715,13 @@ def do_sample(
                     T=T,
                     additional_batch_uc_fields=additional_batch_uc_fields,
                 )
 
                 c, uc = model.conditioner.get_unconditional_conditioning(
                     batch,
                     batch_uc=batch_uc,
                     force_uc_zero_embeddings=force_uc_zero_embeddings,
                     force_cond_zero_embeddings=force_cond_zero_embeddings,
                 )
+                unload_module_gpu(model.conditioner)
 
                 for k in c:
                     if not k == "crossattn":
@@ -805,8 +762,13 @@ def do_sample(
                         model.model, input, sigma, c, **additional_model_inputs
                     )
 
+                load_module_gpu(model.model)
+                load_module_gpu(model.denoiser)
                 samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+                unload_module_gpu(model.model)
+                unload_module_gpu(model.denoiser)
 
+                load_module_gpu(model.first_stage_model)
                 if isinstance(model.first_stage_model.decoder, VideoDecoder):
                     samples_x = model.decode_first_stage(
                         samples_z, timesteps=default(decoding_t, T)
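With these changes do_sample keeps at most one stage resident on the GPU at a time: conditioner, then UNet plus denoiser, then the VAE decoder, trading extra host-device transfers for lower peak VRAM. The traffic pattern in miniature (toy Linear stages, requires a CUDA device):

    import torch

    def load_module_gpu(m):
        m.cuda()

    def unload_module_gpu(m):
        m.cpu()
        torch.cuda.empty_cache()

    stages = [torch.nn.Linear(4, 4) for _ in range(3)]  # stand-ins for conditioner, UNet, VAE
    x = torch.randn(1, 4)
    for stage in stages:
        load_module_gpu(stage)
        x = stage(x.cuda()).cpu()  # only this stage is on the GPU during its call
        unload_module_gpu(stage)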
@@ -814,6 +776,7 @@ def do_sample(
                 else:
                     samples_x = model.decode_first_stage(samples_z)
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+                unload_module_gpu(model.first_stage_model)
 
                 if filter is not None:
                     samples = filter(samples)
@@ -850,6 +813,7 @@ def do_sample_per_step(
                 else:
                     num_samples = [num_samples]
 
+                load_module_gpu(model.conditioner)
                 batch, batch_uc = get_batch(
                     get_unique_embedder_keys_from_conditioner(model.conditioner),
                     value_dict,
@@ -857,13 +821,13 @@ def do_sample_per_step(
                     T=T,
                     additional_batch_uc_fields=additional_batch_uc_fields,
                 )
 
                 c, uc = model.conditioner.get_unconditional_conditioning(
                     batch,
                     batch_uc=batch_uc,
                     force_uc_zero_embeddings=force_uc_zero_embeddings,
                     force_cond_zero_embeddings=force_cond_zero_embeddings,
                 )
+                unload_module_gpu(model.conditioner)
 
                 for k in c:
                     if not k == "crossattn":
@@ -917,6 +881,9 @@ def do_sample_per_step(
                     if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
                     else 0.0
                 )
+
+                load_module_gpu(model.model)
+                load_module_gpu(model.denoiser)
                 samples_z = sampler.sampler_step(
                     s_in * sigmas[step],
                     s_in * sigmas[step + 1],
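For context, the gamma above implements EDM-style churn: extra noise is injected only while sigma lies inside [s_tmin, s_tmax]. A sketch of the standard rule as used in EDM/k-diffusion samplers (the settings and the exact min(...) cap here are assumptions, not quoted from this file):

    s_churn, num_sigmas = 0.5, 30  # hypothetical sampler settings
    s_tmin, s_tmax = 0.05, 10.0
    for sigma in (0.01, 0.3, 12.0):
        gamma = (
            min(s_churn / (num_sigmas - 1), 2**0.5 - 1)
            if s_tmin <= sigma <= s_tmax
            else 0.0
        )
        print(sigma, gamma)  # nonzero only for sigma = 0.3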
@@ -926,6 +893,8 @@ def do_sample_per_step(
                     uc,
                     gamma,
                 )
+                unload_module_gpu(model.model)
+                unload_module_gpu(model.denoiser)
 
                 return samples_z