# generative-models/scripts/demo/sv4d_helpers.py
import copy
import math
import os
from glob import glob
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import cv2
import imageio
import numpy as np
import torch
import torchvision.transforms as TT
from einops import rearrange, repeat
from omegaconf import ListConfig, OmegaConf
from PIL import Image, ImageSequence
from rembg import remove
from torch import autocast
from torchvision.transforms import ToTensor
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.modules.autoencoding.temporal_ae import VideoDecoder
from sgm.modules.diffusionmodules.guiders import (
LinearPredictionGuider,
SpatiotemporalPredictionGuider,
TrapezoidPredictionGuider,
TrianglePredictionGuider,
VanillaCFG,
)
from sgm.modules.diffusionmodules.sampling import (
DPMPP2MSampler,
DPMPP2SAncestralSampler,
EulerAncestralSampler,
EulerEDMSampler,
HeunEDMSampler,
LinearMultistepSampler,
)
from sgm.util import default, instantiate_from_config
def load_module_gpu(model):
model.cuda()
def unload_module_gpu(model):
model.cpu()
torch.cuda.empty_cache()
def initial_model_load(model):
    # cast the denoising network to half precision to halve its VRAM footprint
    model.model.half()
    return model
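# The three helpers above implement a simple CPU<->GPU offloading scheme: only
# the submodule currently in use is resident in VRAM. A minimal usage sketch,
# mirroring how do_sample below drives it (`model` is an sgm DiffusionEngine):
#
#     load_module_gpu(model.conditioner)
#     c, uc = model.conditioner.get_unconditional_conditioning(batch, batch_uc=batch_uc)
#     unload_module_gpu(model.conditioner)  # free VRAM before the denoising loop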
def get_resizing_factor(
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
) -> float:
r_bound = desired_shape[1] / desired_shape[0]
aspect_r = current_shape[1] / current_shape[0]
if r_bound >= 1.0:
if aspect_r >= r_bound:
factor = min(desired_shape) / min(current_shape)
else:
if aspect_r < 1.0:
factor = max(desired_shape) / min(current_shape)
else:
factor = max(desired_shape) / max(current_shape)
else:
if aspect_r <= r_bound:
factor = min(desired_shape) / min(current_shape)
else:
if aspect_r > 1:
factor = max(desired_shape) / min(current_shape)
else:
factor = max(desired_shape) / max(current_shape)
return factor
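# Worked example: desired_shape=(576, 576) gives r_bound = 576 / 576 = 1.0; an
# input frame of current_shape=(480, 640) has aspect_r = 640 / 480 ~ 1.33 >=
# r_bound, so factor = min(576, 576) / min(480, 640) = 576 / 480 = 1.2, i.e.
# the frame is scaled so its short side reaches 576.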
def read_gif(input_path, n_frames):
frames = []
video = Image.open(input_path)
for img in ImageSequence.Iterator(video):
frames.append(img.convert("RGBA"))
if len(frames) == n_frames:
break
return frames
def read_mp4(input_path, n_frames):
frames = []
vidcap = cv2.VideoCapture(input_path)
success, image = vidcap.read()
while success:
frames.append(Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)))
success, image = vidcap.read()
if len(frames) == n_frames:
break
return frames
def save_img(file_name, img):
output_dir = os.path.dirname(file_name)
os.makedirs(output_dir, exist_ok=True)
imageio.imwrite(
file_name,
(((img[0].permute(1, 2, 0) + 1) / 2).cpu().numpy() * 255.0).astype(np.uint8),
)
def save_video(file_name, imgs, fps=10):
output_dir = os.path.dirname(file_name)
os.makedirs(output_dir, exist_ok=True)
img_grid = [
(((img[0].permute(1, 2, 0) + 1) / 2).cpu().numpy() * 255.0).astype(np.uint8)
for img in imgs
]
if file_name.endswith(".gif"):
imageio.mimwrite(file_name, img_grid, fps=fps, loop=0)
else:
imageio.mimwrite(file_name, img_grid, fps=fps)
def read_video(
input_path: str,
n_frames: int,
device: str = "cuda",
):
path = Path(input_path)
is_video_file = False
all_img_paths = []
if path.is_file():
if any([input_path.endswith(x) for x in [".gif", ".mp4"]]):
is_video_file = True
else:
raise ValueError("Path is not a valid video file.")
elif path.is_dir():
all_img_paths = sorted(
[
f
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)[:n_frames]
elif "*" in input_path:
all_img_paths = sorted(glob(input_path))[:n_frames]
else:
raise ValueError
if is_video_file and input_path.endswith(".gif"):
images = read_gif(input_path, n_frames)[:n_frames]
elif is_video_file and input_path.endswith(".mp4"):
images = read_mp4(input_path, n_frames)[:n_frames]
else:
print(f"Loading {len(all_img_paths)} video frames...")
images = [Image.open(img_path) for img_path in all_img_paths]
if len(images) < n_frames:
images = (images + images[::-1])[:n_frames]
if len(images) != n_frames:
raise ValueError(f"Input video contains fewer than {n_frames} frames.")
images_v0 = []
for image in images:
image = ToTensor()(image).unsqueeze(0).to(device)
images_v0.append(image * 2.0 - 1.0)
return images_v0
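# Usage sketch (hypothetical path; read_video accepts a .gif/.mp4 file, a
# directory of images, or a glob pattern):
#
#     frames = read_video("assets/test_video.mp4", n_frames=21)
#     # -> list of 21 tensors of shape (1, C, H, W), scaled to [-1, 1]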
def preprocess_video(
    input_path,
    remove_bg=False,
    n_frames=21,
    W=576,
    H=576,
    output_folder=None,
    image_frame_ratio=0.917,
):
print(f"preprocess {input_path}")
if output_folder is None:
output_folder = os.path.dirname(input_path)
path = Path(input_path)
is_video_file = False
all_img_paths = []
if path.is_file():
if any([input_path.endswith(x) for x in [".gif", ".mp4"]]):
is_video_file = True
else:
raise ValueError("Path is not a valid video file.")
elif path.is_dir():
all_img_paths = sorted(
[
f
for f in path.iterdir()
if f.is_file() and f.suffix.lower() in [".jpg", ".jpeg", ".png"]
]
)[:n_frames]
elif "*" in input_path:
all_img_paths = sorted(glob(input_path))[:n_frames]
else:
raise ValueError
if is_video_file and input_path.endswith(".gif"):
images = read_gif(input_path, n_frames)[:n_frames]
elif is_video_file and input_path.endswith(".mp4"):
images = read_mp4(input_path, n_frames)[:n_frames]
else:
print(f"Loading {len(all_img_paths)} video frames...")
images = [Image.open(img_path) for img_path in all_img_paths]
    if len(images) != n_frames:
        raise ValueError(
            f"Input video contains {len(images)} frames; expected {n_frames}."
        )
    # Remove background (frames that already carry an alpha channel are left as-is)
    for i, image in enumerate(images):
        if remove_bg and image.mode != "RGBA":
            # image.thumbnail([W, H], Image.Resampling.LANCZOS)
            images[i] = remove(image.convert("RGBA"), alpha_matting=True)
# Crop video frames, assume the object is already in the center of the image
white_thresh = 250
images_v0 = []
box_coord = [np.inf, np.inf, 0, 0]
    for image in images:
        image_arr = np.array(image)
        # NOTE: numpy image arrays are (rows, cols, channels); the crop below
        # assumes roughly square, centered frames, so the distinction washes out
        in_w, in_h = image_arr.shape[:2]
        original_center = (in_w // 2, in_h // 2)
        if image.mode == "RGBA":
            # foreground mask from the alpha channel
            _, mask = cv2.threshold(
                np.array(image.split()[-1]), 0, 255, cv2.THRESH_BINARY
            )
        else:
            # assume the input image has a white background
            _, mask = cv2.threshold(
                (np.array(image).mean(-1) <= white_thresh).astype(np.uint8) * 255,
                0,
                255,
                cv2.THRESH_BINARY,
            )
        # grow the union of per-frame foreground bounding boxes
        x, y, w, h = cv2.boundingRect(mask)
        box_coord[0] = min(box_coord[0], x)
        box_coord[1] = min(box_coord[1], y)
        box_coord[2] = max(box_coord[2], x + w)
        box_coord[3] = max(box_coord[3], y + h)
    # expand the union box to a square centered on the frame center
    box_square = max(
        original_center[0] - box_coord[0], original_center[1] - box_coord[1]
    )
    box_square = max(box_square, box_coord[2] - original_center[0])
    box_square = max(box_square, box_coord[3] - original_center[1])
    x, y = original_center[0] - box_square, original_center[1] - box_square
    w = h = 2 * box_square
    box_size = box_square * 2
for image in images:
if image.mode == "RGB":
image = image.convert("RGBA")
image_arr = np.array(image)
side_len = (
int(box_size / image_frame_ratio)
if image_frame_ratio is not None
else in_w
)
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
center = side_len // 2
padded_image[
center - box_size // 2 : center - box_size // 2 + box_size,
center - box_size // 2 : center - box_size // 2 + box_size,
] = image_arr[x : x + w, y : y + h]
rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
# rgba = image.resize((W, H), Image.LANCZOS)
rgba_arr = np.array(rgba) / 255.0
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
image = (rgb * 255).astype(np.uint8)
images_v0.append(image)
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
processed_file = os.path.join(output_folder, f"{base_count:06d}_process_input.mp4")
imageio.mimwrite(processed_file, images_v0, fps=10)
return processed_file
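# Usage sketch (hypothetical input path): recenters and square-crops each
# frame, composites RGBA onto white, and writes a W x H mp4 next to the input:
#
#     processed_file = preprocess_video("assets/test_video.mp4", remove_bg=True)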
def sample_sv3d(
image,
num_frames: Optional[int] = None, # 21 for SV3D
num_steps: Optional[int] = None,
version: str = "sv3d_u",
fps_id: int = 6,
motion_bucket_id: int = 127,
cond_aug: float = 0.02,
decoding_t: int = 14, # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
device: str = "cuda",
polar_rad: Optional[Union[float, List[float]]] = None,
azim_rad: Optional[List[float]] = None,
verbose: Optional[bool] = False,
sv3d_model=None,
):
"""
Simple script to generate a single sample conditioned on an image `input_path` or multiple images, one for each
image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t`.
"""
if sv3d_model is None:
if version == "sv3d_u":
model_config = "scripts/sampling/configs/sv3d_u.yaml"
elif version == "sv3d_p":
model_config = "scripts/sampling/configs/sv3d_p.yaml"
else:
raise ValueError(f"Version {version} does not exist.")
model, filter = load_model(
model_config,
device,
num_frames,
num_steps,
verbose,
)
else:
model = sv3d_model
load_module_gpu(model)
H, W = image.shape[2:]
F = 8
C = 4
shape = (num_frames, C, H // F, W // F)
value_dict = {}
value_dict["cond_frames_without_noise"] = image
value_dict["motion_bucket_id"] = motion_bucket_id
value_dict["fps_id"] = fps_id
value_dict["cond_aug"] = cond_aug
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
if "sv3d_p" in version:
value_dict["polars_rad"] = polar_rad
value_dict["azimuths_rad"] = azim_rad
with torch.no_grad():
with torch.autocast(device):
batch, batch_uc = get_batch_sv3d(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[1, num_frames],
T=num_frames,
device=device,
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=[
"cond_frames",
"cond_frames_without_noise",
],
)
for k in ["crossattn", "concat"]:
uc[k] = repeat(uc[k], "b ... -> b t ...", t=num_frames)
uc[k] = rearrange(uc[k], "b t ... -> (b t) ...", t=num_frames)
c[k] = repeat(c[k], "b ... -> b t ...", t=num_frames)
c[k] = rearrange(c[k], "b t ... -> (b t) ...", t=num_frames)
randn = torch.randn(shape, device=device)
additional_model_inputs = {}
additional_model_inputs["image_only_indicator"] = torch.zeros(
2, num_frames
).to(device)
additional_model_inputs["num_video_frames"] = batch["num_video_frames"]
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
samples_z = model.sampler(denoiser, randn, cond=c, uc=uc)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)
model.en_and_decode_n_samples_a_time = decoding_t
samples_x = model.decode_first_stage(samples_z)
            # the default orbit ends back at the input azimuth, so reuse the conditioning image
            samples_x[-1:] = value_dict["cond_frames_without_noise"]
samples = torch.clamp(samples_x, min=-1.0, max=1.0)
unload_module_gpu(model)
return samples
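# Usage sketch, assuming `image` is a (1, 3, 576, 576) tensor in [-1, 1], an
# sv3d_p config/checkpoint is available, and the SV3D convention that
# polar = 90 deg - elevation (angles in radians):
#
#     polars = np.deg2rad([90 - 10] * 21).tolist()
#     azims = np.deg2rad(np.linspace(0, 360, 21 + 1)[1:]).tolist()
#     views = sample_sv3d(image, num_frames=21, num_steps=50, version="sv3d_p",
#                         polar_rad=polars, azim_rad=azims)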
def decode_latents(
    model, samples_z, img_matrix, frame_indices, view_indices, timesteps
):
    load_module_gpu(model.first_stage_model)
    for t in frame_indices:
        for v in view_indices:
            # cells with t == 0 or v == 0 already hold conditioning images; skip them
            if t != 0 and v != 0:
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(samples_z[t, v][None], timesteps=timesteps)
else:
samples_x = model.decode_first_stage(samples_z[t, v][None])
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
img_matrix[t][v] = samples * 2 - 1
unload_module_gpu(model.first_stage_model)
return img_matrix
def init_embedder_options_no_st(keys, init_dict, prompt=None, negative_prompt=None):
# Hardcoded demo settings; might undergo some changes in the future
value_dict = {}
for key in keys:
if key == "txt":
if prompt is None:
prompt = "A professional photograph of an astronaut riding a pig"
if negative_prompt is None:
negative_prompt = ""
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
if key == "original_size_as_tuple":
orig_width = init_dict["orig_width"]
orig_height = init_dict["orig_height"]
value_dict["orig_width"] = orig_width
value_dict["orig_height"] = orig_height
if key == "crop_coords_top_left":
crop_coord_top = 0
crop_coord_left = 0
value_dict["crop_coords_top"] = crop_coord_top
value_dict["crop_coords_left"] = crop_coord_left
if key == "aesthetic_score":
value_dict["aesthetic_score"] = 6.0
value_dict["negative_aesthetic_score"] = 2.5
if key == "target_size_as_tuple":
value_dict["target_width"] = init_dict["target_width"]
value_dict["target_height"] = init_dict["target_height"]
if key in ["fps_id", "fps"]:
fps = 6
value_dict["fps"] = fps
value_dict["fps_id"] = fps - 1
if key == "motion_bucket_id":
mb_id = 127
value_dict["motion_bucket_id"] = mb_id
if key == "noise_level":
value_dict["noise_level"] = 0
return value_dict
def get_discretization_no_st(discretization, options, key=1):
if discretization == "LegacyDDPMDiscretization":
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
elif discretization == "EDMDiscretization":
sigma_min = options.get("sigma_min", 0.03)
sigma_max = options.get("sigma_max", 14.61)
rho = options.get("rho", 3.0)
discretization_config = {
"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
"params": {
"sigma_min": sigma_min,
"sigma_max": sigma_max,
"rho": rho,
},
}
    else:
        raise ValueError(f"unknown discretization {discretization}!")
    return discretization_config
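# Example: get_discretization_no_st("EDMDiscretization", options={}) yields
#
#     {"target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
#      "params": {"sigma_min": 0.03, "sigma_max": 14.61, "rho": 3.0}}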
def get_guider_no_st(options, key):
guider = [
"VanillaCFG",
"IdentityGuider",
"LinearPredictionGuider",
"TrianglePredictionGuider",
"TrapezoidPredictionGuider",
"SpatiotemporalPredictionGuider",
][options.get("guider", 2)]
additional_guider_kwargs = (
options["additional_guider_kwargs"]
if "additional_guider_kwargs" in options
else {}
)
if guider == "IdentityGuider":
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif guider == "VanillaCFG":
scale_schedule = "Identity"
if scale_schedule == "Identity":
scale = options.get("cfg", 5.0)
scale_schedule_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentitySchedule",
"params": {"scale": scale},
}
elif scale_schedule == "Oscillating":
small_scale = 4.0
large_scale = 16.0
sigma_cutoff = 1.0
scale_schedule_config = {
"target": "sgm.modules.diffusionmodules.guiders.OscillatingSchedule",
"params": {
"small_scale": small_scale,
"large_scale": large_scale,
"sigma_cutoff": sigma_cutoff,
},
}
else:
raise NotImplementedError
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
"params": {
"scale_schedule_config": scale_schedule_config,
**additional_guider_kwargs,
},
}
elif guider == "LinearPredictionGuider":
max_scale = options.get("cfg", 1.5)
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.LinearPredictionGuider",
"params": {
"max_scale": max_scale,
"num_frames": options["num_frames"],
**additional_guider_kwargs,
},
}
elif guider == "TrianglePredictionGuider":
max_scale = options.get("cfg", 1.5)
period = options.get("period", 1.0)
period_fusing = options.get("period_fusing", "max")
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider",
"params": {
"max_scale": max_scale,
"num_frames": options["num_frames"],
"period": period,
"period_fusing": period_fusing,
**additional_guider_kwargs,
},
}
elif guider == "TrapezoidPredictionGuider":
max_scale = options.get("cfg", 1.5)
edge_perc = options.get("edge_perc", 0.1)
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.TrapezoidPredictionGuider",
"params": {
"max_scale": max_scale,
"num_frames": options["num_frames"],
"edge_perc": edge_perc,
**additional_guider_kwargs,
},
}
elif guider == "SpatiotemporalPredictionGuider":
max_scale = options.get("cfg", 1.5)
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider",
"params": {
"max_scale": max_scale,
"num_frames": options["num_frames"],
**additional_guider_kwargs,
},
}
else:
raise NotImplementedError
return guider_config
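# Example: options={"guider": 3, "cfg": 2.5, "num_frames": 21} selects
# TrianglePredictionGuider and yields
#
#     {"target": "sgm.modules.diffusionmodules.guiders.TrianglePredictionGuider",
#      "params": {"max_scale": 2.5, "num_frames": 21, "period": 1.0,
#                 "period_fusing": "max"}}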
def get_sampler_no_st(sampler_name, steps, discretization_config, guider_config, key=1):
    if sampler_name in ("EulerEDMSampler", "HeunEDMSampler"):
s_churn = 0.0
s_tmin = 0.0
s_tmax = 999.0
s_noise = 1.0
if sampler_name == "EulerEDMSampler":
sampler = EulerEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=False,
)
elif sampler_name == "HeunEDMSampler":
sampler = HeunEDMSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
s_churn=s_churn,
s_tmin=s_tmin,
s_tmax=s_tmax,
s_noise=s_noise,
verbose=False,
)
    elif sampler_name in ("EulerAncestralSampler", "DPMPP2SAncestralSampler"):
s_noise = 1.0
eta = 1.0
if sampler_name == "EulerAncestralSampler":
sampler = EulerAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=False,
)
elif sampler_name == "DPMPP2SAncestralSampler":
sampler = DPMPP2SAncestralSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
eta=eta,
s_noise=s_noise,
verbose=False,
)
elif sampler_name == "DPMPP2MSampler":
sampler = DPMPP2MSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
verbose=False,
)
elif sampler_name == "LinearMultistepSampler":
order = 4
sampler = LinearMultistepSampler(
num_steps=steps,
discretization_config=discretization_config,
guider_config=guider_config,
order=order,
verbose=False,
)
else:
raise ValueError(f"unknown sampler {sampler_name}!")
return sampler
def init_sampling_no_st(
key=1,
options: Optional[Dict[str, int]] = None,
):
options = {} if options is None else options
num_rows, num_cols = 1, 1
steps = options.get("num_steps", 40)
sampler = [
"EulerEDMSampler",
"HeunEDMSampler",
"EulerAncestralSampler",
"DPMPP2SAncestralSampler",
"DPMPP2MSampler",
"LinearMultistepSampler",
][options.get("sampler", 0)]
discretization = [
"LegacyDDPMDiscretization",
"EDMDiscretization",
][options.get("discretization", 1)]
discretization_config = get_discretization_no_st(
discretization, options=options, key=key
)
guider_config = get_guider_no_st(options=options, key=key)
sampler = get_sampler_no_st(
sampler, steps, discretization_config, guider_config, key=key
)
return sampler, num_rows, num_cols
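# Usage sketch: with the defaults above this builds an EulerEDMSampler over an
# EDMDiscretization with a LinearPredictionGuider (guider index 2):
#
#     sampler, _, _ = init_sampling_no_st(
#         options={"num_steps": 50, "cfg": 2.5, "num_frames": 21}
#     )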
def run_img2vid(
version_dict,
model,
image,
seed=23,
polar_rad=[10] * 21,
azim_rad=np.linspace(0, 360, 21 + 1)[1:],
cond_motion=None,
cond_view=None,
decoding_t=None,
):
options = version_dict["options"]
H = version_dict["H"]
W = version_dict["W"]
T = version_dict["T"]
C = version_dict["C"]
F = version_dict["f"]
init_dict = {
"orig_width": 576,
"orig_height": 576,
"target_width": W,
"target_height": H,
}
ukeys = set(get_unique_embedder_keys_from_conditioner(model.conditioner))
value_dict = init_embedder_options_no_st(
ukeys,
init_dict,
negative_prompt=options.get("negative_promt", ""),
prompt="A 3D model.",
)
if "fps" not in ukeys:
value_dict["fps"] = 6
value_dict["is_image"] = 0
value_dict["is_webvid"] = 0
value_dict["image_only_indicator"] = 0
cond_aug = 0.00
if cond_motion is not None:
value_dict["cond_frames_without_noise"] = cond_motion
value_dict["cond_frames"] = (
cond_motion[:, None].repeat(1, cond_view.shape[0], 1, 1, 1).flatten(0, 1)
)
value_dict["cond_motion"] = cond_motion
value_dict["cond_view"] = cond_view
else:
value_dict["cond_frames_without_noise"] = image
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
value_dict["cond_aug"] = cond_aug
value_dict["polar_rad"] = polar_rad
value_dict["azimuth_rad"] = azim_rad
value_dict["rotated"] = False
value_dict["cond_motion"] = cond_motion
value_dict["cond_view"] = cond_view
# seed_everything(seed)
options["num_frames"] = T
sampler, num_rows, num_cols = init_sampling_no_st(options=options)
num_samples = num_rows * num_cols
samples = do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
T=T,
batch2model_input=["num_video_frames", "image_only_indicator"],
force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None),
return_latents=False,
decoding_t=decoding_t,
)
return samples
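# Usage sketch with a hypothetical version_dict (the keys are exactly the ones
# read above; values follow the 576x576, latent-factor-8 demo setup):
#
#     version_dict = {
#         "H": 576, "W": 576, "T": 21, "C": 4, "f": 8,
#         "options": {"num_steps": 50, "cfg": 2.5},
#     }
#     samples = run_img2vid(version_dict, model, image, decoding_t=14)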
def prepare_inputs(
    frame_indices,
    img_matrix,
    v0,
    view_indices,
    model,
    version_dict,
    seed,
    polars,
    azims,
):
load_module_gpu(model.conditioner)
    # forward sampling (conditioning frames in original temporal order)
    forward_frame_indices = frame_indices.copy()
t0 = forward_frame_indices[0]
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in forward_frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
forward_inputs = prepare_sampling(
version_dict,
model,
image,
seed,
polars,
azims,
cond_motion,
cond_view,
)
    # backward sampling (conditioning frames in reverse temporal order)
    backward_frame_indices = frame_indices[::-1].copy()
t0 = backward_frame_indices[0]
image = img_matrix[t0][v0]
cond_motion = torch.cat([img_matrix[t][v0] for t in backward_frame_indices], 0)
cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
backward_inputs = prepare_sampling(
version_dict,
model,
image,
seed,
polars,
azims,
cond_motion,
cond_view,
)
unload_module_gpu(model.conditioner)
return forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices
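# prepare_inputs assumes img_matrix is a nested list indexed as
# img_matrix[frame][view], each cell a (1, 3, H, W) tensor in [-1, 1]; column
# v0 holds the input-video frames and row t0 the reference multi-view images.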
def do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: Optional[List] = None,
force_cond_zero_embeddings: Optional[List] = None,
    batch2model_input: Optional[List] = None,
return_latents=False,
filter=None,
T=None,
additional_batch_uc_fields=None,
decoding_t=None,
):
force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
batch2model_input = default(batch2model_input, [])
additional_batch_uc_fields = default(additional_batch_uc_fields, [])
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
if T is not None:
num_samples = [num_samples, T]
else:
num_samples = [num_samples]
load_module_gpu(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
unload_module_gpu(model.conditioner)
for k in c:
                    if k != "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
)
additional_model_inputs = {}
for k in batch2model_input:
if k == "image_only_indicator":
assert T is not None
if isinstance(
sampler.guider,
(
VanillaCFG,
LinearPredictionGuider,
TrianglePredictionGuider,
TrapezoidPredictionGuider,
SpatiotemporalPredictionGuider,
),
):
additional_model_inputs[k] = torch.zeros(
num_samples[0] * 2, num_samples[1]
).to("cuda")
else:
additional_model_inputs[k] = torch.zeros(num_samples).to(
"cuda"
)
else:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
load_module_gpu(model.model)
load_module_gpu(model.denoiser)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
unload_module_gpu(model.model)
unload_module_gpu(model.denoiser)
load_module_gpu(model.first_stage_model)
model.en_and_decode_n_samples_a_time = decoding_t
if isinstance(model.first_stage_model.decoder, VideoDecoder):
samples_x = model.decode_first_stage(
samples_z, timesteps=default(decoding_t, T)
)
else:
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_module_gpu(model.first_stage_model)
if filter is not None:
samples = filter(samples)
if return_latents:
return samples, samples_z
return samples
def prepare_sampling_(
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings: Optional[List] = None,
force_cond_zero_embeddings: Optional[List] = None,
    batch2model_input: Optional[List] = None,
T=None,
additional_batch_uc_fields=None,
):
force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
batch2model_input = default(batch2model_input, [])
additional_batch_uc_fields = default(additional_batch_uc_fields, [])
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
if T is not None:
num_samples = [num_samples, T]
else:
num_samples = [num_samples]
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
T=T,
additional_batch_uc_fields=additional_batch_uc_fields,
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
force_cond_zero_embeddings=force_cond_zero_embeddings,
)
for k in c:
                    if k != "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
)
additional_model_inputs = {}
for k in batch2model_input:
if k == "image_only_indicator":
assert T is not None
if isinstance(
sampler.guider,
(
VanillaCFG,
LinearPredictionGuider,
TrianglePredictionGuider,
TrapezoidPredictionGuider,
SpatiotemporalPredictionGuider,
),
):
additional_model_inputs[k] = torch.zeros(
num_samples[0] * 2, num_samples[1]
).to("cuda")
else:
additional_model_inputs[k] = torch.zeros(num_samples).to(
"cuda"
)
else:
additional_model_inputs[k] = batch[k]
return c, uc, additional_model_inputs
def do_sample_per_step(
    model, sampler, noisy_latents, c, uc, step, additional_model_inputs
):
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
noisy_latents_scaled, s_in, sigmas, num_sigmas, _, _ = (
sampler.prepare_sampling_loop(
noisy_latents.clone(), c, uc, sampler.num_steps
)
)
if step == 0:
latents = noisy_latents_scaled
else:
latents = noisy_latents
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
gamma = (
min(sampler.s_churn / (num_sigmas - 1), 2**0.5 - 1)
if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
else 0.0
)
load_module_gpu(model.model)
load_module_gpu(model.denoiser)
samples_z = sampler.sampler_step(
s_in * sigmas[step],
s_in * sigmas[step + 1],
denoiser,
latents,
c,
uc,
gamma,
)
return samples_z
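# do_sample_per_step advances the sampler by a single denoising step, so a
# caller can alternate between the forward and backward inputs built by
# prepare_inputs. A minimal sketch of a plain forward-only loop:
#
#     for step in range(sampler.num_steps):
#         latents = do_sample_per_step(
#             model, sampler, latents, c, uc, step, additional_model_inputs
#         )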
def prepare_sampling(
version_dict,
model,
image,
seed=23,
polar_rad=[10] * 21,
azim_rad=np.linspace(0, 360, 21 + 1)[1:],
cond_motion=None,
cond_view=None,
):
options = version_dict["options"]
H = version_dict["H"]
W = version_dict["W"]
T = version_dict["T"]
C = version_dict["C"]
F = version_dict["f"]
init_dict = {
"orig_width": 576,
"orig_height": 576,
"target_width": W,
"target_height": H,
}
ukeys = set(get_unique_embedder_keys_from_conditioner(model.conditioner))
value_dict = init_embedder_options_no_st(
ukeys,
init_dict,
negative_prompt=options.get("negative_promt", ""),
prompt="A 3D model.",
)
if "fps" not in ukeys:
value_dict["fps"] = 6
value_dict["is_image"] = 0
value_dict["is_webvid"] = 0
value_dict["image_only_indicator"] = 0
cond_aug = 0.00
if cond_motion is not None:
value_dict["cond_frames_without_noise"] = cond_motion
value_dict["cond_frames"] = (
cond_motion[:, None].repeat(1, cond_view.shape[0], 1, 1, 1).flatten(0, 1)
)
value_dict["cond_motion"] = cond_motion
value_dict["cond_view"] = cond_view
else:
value_dict["cond_frames_without_noise"] = image
value_dict["cond_frames"] = image + cond_aug * torch.randn_like(image)
value_dict["cond_aug"] = cond_aug
value_dict["polar_rad"] = polar_rad
value_dict["azimuth_rad"] = azim_rad
value_dict["rotated"] = False
value_dict["cond_motion"] = cond_motion
value_dict["cond_view"] = cond_view
# seed_everything(seed)
options["num_frames"] = T
sampler, num_rows, num_cols = init_sampling_no_st(options=options)
num_samples = num_rows * num_cols
c, uc, additional_model_inputs = prepare_sampling_(
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=options.get("force_uc_zero_embeddings", None),
force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None),
batch2model_input=["num_video_frames", "image_only_indicator"],
T=T,
)
return c, uc, additional_model_inputs, sampler
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
def get_batch_sv3d(keys, value_dict, N, T, device):
batch = {}
batch_uc = {}
for key in keys:
if key == "fps_id":
batch[key] = (
torch.tensor([value_dict["fps_id"]])
.to(device)
.repeat(int(math.prod(N)))
)
elif key == "motion_bucket_id":
batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]])
.to(device)
.repeat(int(math.prod(N)))
)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to(device),
"1 -> b",
b=math.prod(N),
)
elif key == "cond_frames" or key == "cond_frames_without_noise":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
elif key == "polars_rad" or key == "azimuths_rad":
batch[key] = torch.tensor(value_dict[key]).to(device).repeat(N[0])
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
def get_batch(
    keys,
    value_dict: dict,
    N: Union[List, ListConfig],
    device: str = "cuda",
    T: Optional[int] = None,
    additional_batch_uc_fields: Optional[List[str]] = None,
):
    additional_batch_uc_fields = default(additional_batch_uc_fields, [])
    batch = {}
    batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = [value_dict["prompt"]] * math.prod(N)
batch_uc["txt"] = [value_dict["negative_prompt"]] * math.prod(N)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(math.prod(N), 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(math.prod(N), 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]])
.to(device)
.repeat(math.prod(N), 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(math.prod(N), 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(math.prod(N), 1)
)
elif key == "fps":
batch[key] = (
torch.tensor([value_dict["fps"]]).to(device).repeat(math.prod(N))
)
elif key == "fps_id":
batch[key] = (
torch.tensor([value_dict["fps_id"]]).to(device).repeat(math.prod(N))
)
elif key == "motion_bucket_id":
batch[key] = (
torch.tensor([value_dict["motion_bucket_id"]])
.to(device)
.repeat(math.prod(N))
)
elif key == "pool_image":
batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=math.prod(N)).to(
device, dtype=torch.half
)
elif key == "is_image":
batch[key] = (
torch.tensor([value_dict["is_image"]])
.to(device)
.repeat(math.prod(N))
.long()
)
elif key == "is_webvid":
batch[key] = (
torch.tensor([value_dict["is_webvid"]])
.to(device)
.repeat(math.prod(N))
.long()
)
elif key == "cond_aug":
batch[key] = repeat(
torch.tensor([value_dict["cond_aug"]]).to("cuda"),
"1 -> b",
b=math.prod(N),
)
elif (
key == "cond_frames"
or key == "cond_frames_without_noise"
or key == "back_frames"
):
# batch[key] = repeat(value_dict[key], "1 ... -> b ...", b=N[0])
batch[key] = value_dict[key]
elif key == "interpolation_context":
batch[key] = repeat(
value_dict["interpolation_context"], "b ... -> (b n) ...", n=N[1]
)
elif key == "start_frame":
assert T is not None
batch[key] = repeat(value_dict[key], "b ... -> (b t) ...", t=T)
elif key == "polar_rad" or key == "azimuth_rad":
batch[key] = (
torch.tensor(value_dict[key]).to(device).repeat(math.prod(N) // T)
)
elif key == "rotated":
batch[key] = (
torch.tensor([value_dict["rotated"]]).to(device).repeat(math.prod(N))
)
else:
batch[key] = value_dict[key]
if T is not None:
batch["num_video_frames"] = T
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
elif key in additional_batch_uc_fields and key not in batch_uc:
batch_uc[key] = copy.copy(batch[key])
return batch, batch_uc
def load_model(
config: str,
device: str,
num_frames: int,
num_steps: int,
verbose: bool = False,
):
config = OmegaConf.load(config)
if device == "cuda":
config.model.params.conditioner_config.params.emb_models[
0
].params.open_clip_embedding_config.params.init_device = device
config.model.params.sampler_config.params.verbose = verbose
config.model.params.sampler_config.params.num_steps = num_steps
config.model.params.sampler_config.params.guider_config.params.num_frames = (
num_frames
)
if device == "cuda":
with torch.device(device):
model = instantiate_from_config(config.model).to(device).eval()
else:
model = instantiate_from_config(config.model).to(device).eval()
filter = DeepFloydDataFiltering(verbose=False, device=device)
return model, filter
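# Usage sketch tying load_model to sample_sv3d (config path as referenced in
# sample_sv3d above):
#
#     model, filter = load_model(
#         "scripts/sampling/configs/sv3d_u.yaml", "cuda", num_frames=21,
#         num_steps=50, verbose=False,
#     )
#     model = initial_model_load(model)
#     views = sample_sv3d(image, num_frames=21, sv3d_model=model)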