Mirror of https://github.com/Stability-AI/generative-models.git
Commit: update sv4d sampling script and readme
.vscode/launch.json (vendored, new file): 23 added lines
@@ -0,0 +1,23 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: Remote Attach",
+            "type": "debugpy",
+            "request": "attach",
+            "connect": {
+                "host": "localhost",
+                "port": 5678
+            },
+            "pathMappings": [
+                {
+                    "localRoot": "${workspaceFolder}",
+                    "remoteRoot": "."
+                }
+            ]
+        }
+    ]
+}
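Not part of the diff: the configuration above attaches VS Code's Python debugger to a process that is already listening for debugpy connections on localhost:5678. A minimal sketch of the script side, assuming the debugpy package is installed (everything below is illustrative, not taken from the repository):

    # Illustrative: make a running sampling script accept the
    # "Python Debugger: Remote Attach" configuration defined above.
    import debugpy

    debugpy.listen(("localhost", 5678))  # same host and port as launch.json
    print("Waiting for the IDE to attach on port 5678 ...")
    debugpy.wait_for_client()            # block until VS Code connects
    debugpy.breakpoint()                 # optional: pause here once a debugger is attached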
README.md (mode changed: normal file → executable file): 7 changed lines
@@ -9,9 +9,9 @@
 - We are releasing **[Stable Video 4D (SV4D)](https://huggingface.co/stabilityai/sv4d)**, a video-to-4D diffusion model for novel-view video synthesis. For research purposes:
   - **SV4D** was trained to generate 40 frames (5 video frames x 8 camera views) at 576x576 resolution, given 5 context frames (the input video), and 8 reference views (synthesised from the first frame of the input video, using a multi-view diffusion model like SV3D) of the same size, ideally white-background images with one object.
   - To generate longer novel-view videos (21 frames), we propose a novel sampling method using SV4D, by first sampling 5 anchor frames and then densely sampling the remaining frames while maintaining temporal consistency.
-  - Please check our [project page](https://sv4d.github.io), [tech report](https://sv4d.github.io/static/sv4d_technical_report.pdf) and [video summary](https://www.youtube.com/watch?v=RBP8vdAWTgk) for more details.
+  - Please check our [project page](), [tech report]() and [video summary]() for more details.
 
-**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [sv4d.safetensors](https://huggingface.co/stabilityai/sv4d) and [sv3d_u.safetensors](https://huggingface.co/stabilityai/sv3d) from HuggingFace into `checkpoints/`)
+**QUICKSTART** : `python scripts/sampling/simple_video_sample_4d.py --input_path assets/test_video1.mp4 --output_folder outputs/sv4d` (after downloading [SV4D](https://huggingface.co/stabilityai/sv4d) and [SV3D_u]((https://huggingface.co/stabilityai/sv3d)) from HuggingFace)
 
 To run **SV4D** on a single input video of 21 frames:
 - Download SV3D models (`sv3d_u.safetensors` and `sv3d_p.safetensors`) from [here](https://huggingface.co/stabilityai/sv3d) and SV4D model (`sv4d.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d) to `checkpoints/`
@@ -23,7 +23,8 @@ To run **SV4D** on a single input video of 21 frames:
 - `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
 - `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
 - `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
-- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos (with noisy background), try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
+- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
+- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (of frames decoded at a time) or lower video resolution like `--img_size=512`.
 
 
 
scripts/demo/sv4d_helpers.py (mode changed: normal file → executable file): 125 changed lines
@@ -36,6 +36,20 @@ from sgm.modules.diffusionmodules.sampling import (
 from sgm.util import default, instantiate_from_config
 
 
+def load_module_gpu(model):
+    model.cuda()
+
+
+def unload_module_gpu(model):
+    model.cpu()
+    torch.cuda.empty_cache()
+
+
+def initial_model_load(model):
+    model.model.half()
+    return model
+
+
 def get_resizing_factor(
     desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
 ) -> float:
@@ -60,75 +74,11 @@ def get_resizing_factor(
     return factor
 
 
-def load_img_for_prediction_no_st(
-    image_path: str,
-    mask_path: str,
-    W: int,
-    H: int,
-    crop_h: int,
-    crop_w: int,
-    device="cuda",
-) -> torch.Tensor:
-    image = Image.open(image_path)
-    if image is None:
-        return None
-    image = np.array(image).astype(np.float32) / 255
-    h, w = image.shape[:2]
-    rotated = 0
-
-    mask = None
-    if mask_path is not None:
-        mask = Image.open(mask_path)
-        mask = np.array(mask).astype(np.float32) / 255
-        mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype(
-            np.float32
-        )
-    elif image.shape[-1] == 4:
-        mask = image[:, :, 3:]
-
-    if mask is not None:
-        image = image[:, :, :3] * mask + (1 - mask)
-    # if "DAVIS" in image_path:
-    #     y, x, _ = np.where(mask > 0)
-    #     x_mean, y_mean = np.mean(x), np.mean(y)
-    # else:
-    #     x_mean, y_mean = w//2, h//2
-    # h_new = int(max(crop_h, crop_w) * 1.33)
-    # x_min = max(int(x_mean - h_new//2), 0)
-    # y_min = max(int(y_mean - h_new//2), 0)
-    # image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
-    # h_crop, w_crop = image_cropped.shape[:2]
-    # h_new = max(h_crop, w_crop)
-    # top = max((h_new - h_crop) // 2, 0)
-    # left = max((h_new - w_crop) // 2, 0)
-    # image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
-    # image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
-    # image = image_padded
-    # h, w = image.shape[:2]
-
-    image = image.transpose(2, 0, 1)
-    image = torch.from_numpy(image).to(dtype=torch.float32)
-    image = image.unsqueeze(0)
-
-    rfs = get_resizing_factor((H, W), (h, w))
-    resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
-    top = (resize_size[0] - H) // 2
-    left = (resize_size[1] - W) // 2
-
-    image = torch.nn.functional.interpolate(
-        image, resize_size, mode="area", antialias=False
-    )
-    image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
-    return image.to(device) * 2.0 - 1.0, rotated
-
-
 def read_gif(input_path, n_frames):
     frames = []
     video = Image.open(input_path)
-    if video.n_frames < n_frames:
-        return frames
     for img in ImageSequence.Iterator(video):
-        frames.append(img.convert("RGB"))
+        frames.append(img.convert("RGBA"))
         if len(frames) == n_frames:
             break
     return frames
@@ -206,16 +156,17 @@ def read_video(
     print(f"Loading {len(all_img_paths)} video frames...")
     images = [Image.open(img_path) for img_path in all_img_paths]
 
+    if len(images) < n_frames:
+        images = (images + images[::-1])[:n_frames]
+
     if len(images) != n_frames:
-        raise ValueError("Input video contains fewer than {n_frames} frames.")
+        raise ValueError(f"Input video contains fewer than {n_frames} frames.")
 
     # Remove background and crop video frames
     images_v0 = []
-    for image in images:
+    for t, image in enumerate(images):
         if remove_bg:
-            if image.mode == "RGBA":
-                pass
-            else:
+            if image.mode != "RGBA":
                 image.thumbnail([W, H], Image.Resampling.LANCZOS)
                 image = remove(image.convert("RGBA"), alpha_matting=True)
             image_arr = np.array(image)
@@ -225,11 +176,12 @@ def read_video(
             )
             x, y, w, h = cv2.boundingRect(mask)
             max_size = max(w, h)
-            side_len = (
-                int(max_size / image_frame_ratio)
-                if image_frame_ratio is not None
-                else in_w
-            )
+            if t == 0:
+                side_len = (
+                    int(max_size / image_frame_ratio)
+                    if image_frame_ratio is not None
+                    else in_w
+                )
             padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
             center = side_len // 2
             padded_image[
@@ -239,7 +191,9 @@ def read_video(
             rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
             rgba_arr = np.array(rgba) / 255.0
             rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
-            images = Image.fromarray((rgb * 255).astype(np.uint8))
+            image = Image.fromarray((rgb * 255).astype(np.uint8))
+        else:
+            image = image.convert("RGB").resize((W, H), Image.LANCZOS)
         image = ToTensor()(image).unsqueeze(0).to(device)
         images_v0.append(image * 2.0 - 1.0)
     return images_v0
@@ -341,11 +295,13 @@ def sample_sv3d(
 
 
 def decode_latents(model, samples_z, timesteps):
+    load_module_gpu(model.first_stage_model)
     if isinstance(model.first_stage_model.decoder, VideoDecoder):
         samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
     else:
         samples_x = model.decode_first_stage(samples_z)
     samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+    unload_module_gpu(model.first_stage_model)
     return samples
 
 
@@ -751,6 +707,7 @@ def do_sample(
         else:
             num_samples = [num_samples]
 
+        load_module_gpu(model.conditioner)
         batch, batch_uc = get_batch(
             get_unique_embedder_keys_from_conditioner(model.conditioner),
             value_dict,
@@ -758,13 +715,13 @@ def do_sample(
             T=T,
             additional_batch_uc_fields=additional_batch_uc_fields,
         )
 
         c, uc = model.conditioner.get_unconditional_conditioning(
             batch,
             batch_uc=batch_uc,
             force_uc_zero_embeddings=force_uc_zero_embeddings,
             force_cond_zero_embeddings=force_cond_zero_embeddings,
         )
+        unload_module_gpu(model.conditioner)
 
         for k in c:
             if not k == "crossattn":
@@ -805,8 +762,13 @@ def do_sample(
                 model.model, input, sigma, c, **additional_model_inputs
             )
 
+        load_module_gpu(model.model)
+        load_module_gpu(model.denoiser)
         samples_z = sampler(denoiser, randn, cond=c, uc=uc)
+        unload_module_gpu(model.model)
+        unload_module_gpu(model.denoiser)
 
+        load_module_gpu(model.first_stage_model)
         if isinstance(model.first_stage_model.decoder, VideoDecoder):
             samples_x = model.decode_first_stage(
                 samples_z, timesteps=default(decoding_t, T)
@@ -814,6 +776,7 @@ def do_sample(
         else:
             samples_x = model.decode_first_stage(samples_z)
         samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
+        unload_module_gpu(model.first_stage_model)
 
         if filter is not None:
             samples = filter(samples)
@@ -850,6 +813,7 @@ def do_sample_per_step(
         else:
             num_samples = [num_samples]
 
+        load_module_gpu(model.conditioner)
         batch, batch_uc = get_batch(
             get_unique_embedder_keys_from_conditioner(model.conditioner),
             value_dict,
@@ -857,13 +821,13 @@ def do_sample_per_step(
             T=T,
             additional_batch_uc_fields=additional_batch_uc_fields,
         )
 
         c, uc = model.conditioner.get_unconditional_conditioning(
             batch,
             batch_uc=batch_uc,
             force_uc_zero_embeddings=force_uc_zero_embeddings,
             force_cond_zero_embeddings=force_cond_zero_embeddings,
         )
+        unload_module_gpu(model.conditioner)
 
         for k in c:
             if not k == "crossattn":
@@ -917,6 +881,9 @@ def do_sample_per_step(
             if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
             else 0.0
         )
+
+        load_module_gpu(model.model)
+        load_module_gpu(model.denoiser)
         samples_z = sampler.sampler_step(
             s_in * sigmas[step],
             s_in * sigmas[step + 1],
@@ -926,6 +893,8 @@ def do_sample_per_step(
             uc,
             gamma,
         )
+        unload_module_gpu(model.model)
+        unload_module_gpu(model.denoiser)
 
         return samples_z
 
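Not part of the diff: the helpers added in this file implement a simple CPU/GPU offloading scheme. Each stage of the pipeline (conditioner, denoising UNet, first-stage autoencoder) is moved onto the GPU with load_module_gpu only while it is needed, then moved back to the CPU and the CUDA cache is cleared with unload_module_gpu; initial_model_load additionally casts the UNet to half precision. A self-contained sketch of the same pattern, with illustrative module names that are not from the repository:

    # Illustrative sketch of the load/unload pattern introduced in sv4d_helpers.py.
    import torch

    def load_module_gpu(module: torch.nn.Module) -> None:
        module.cuda()

    def unload_module_gpu(module: torch.nn.Module) -> None:
        module.cpu()
        torch.cuda.empty_cache()  # return cached blocks so the next stage fits in VRAM

    class TinyStage(torch.nn.Module):
        """Stand-in for one pipeline stage (e.g. conditioner or VAE)."""

        def __init__(self) -> None:
            super().__init__()
            self.proj = torch.nn.Linear(16, 16)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)

    if torch.cuda.is_available():
        stage = TinyStage().half()  # analogous to initial_model_load casting to fp16
        x = torch.randn(4, 16, dtype=torch.float16)
        load_module_gpu(stage)      # bring weights onto the GPU only for this step
        y = stage(x.cuda())
        unload_module_gpu(stage)    # free VRAM before the next stage runs
        print(y.shape)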
scripts/sampling/configs/sv4d.yaml (mode changed: normal file → executable file): 6 changed lines
@@ -93,12 +93,6 @@ model:
         sigma_sampler_config:
           target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
 
-      # - input_key: cond_aug
-      #   is_trainable: False
-      #   target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
-      #   params:
-      #     outdim: 256
-
       - input_key: polar_rad
         is_trainable: False
         target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
scripts/sampling/simple_video_sample_4d.py (mode changed: normal file → executable file): 9 changed lines
@@ -13,6 +13,7 @@ from fire import Fire
 from scripts.demo.sv4d_helpers import (
     decode_latents,
     load_model,
+    initial_model_load,
     read_video,
     run_img2vid,
     run_img2vid_per_step,
@@ -26,6 +27,7 @@ def sample(
     output_folder: Optional[str] = "outputs/sv4d",
     num_steps: Optional[int] = 20,
     sv3d_version: str = "sv3d_u",  # sv3d_u or sv3d_p
+    img_size: int = 576,  # image resolution
     fps_id: int = 6,
     motion_bucket_id: int = 127,
     cond_aug: float = 1e-5,
@@ -47,7 +49,7 @@ def sample(
     V = 8  # number of views per sample
     F = 8  # vae factor to downsize image->latent
     C = 4
-    H, W = 576, 576
+    H, W = img_size, img_size
     n_frames = 21  # number of input and output video frames
     n_views = V + 1  # number of output video views (1 input view + 8 novel views)
     n_views_sv3d = 21
@@ -64,7 +66,7 @@ def sample(
         "f": F,
         "options": {
             "discretization": 1,
-            "cfg": 2.5,
+            "cfg": 3.0,
             "sigma_min": 0.002,
             "sigma_max": 700.0,
             "rho": 7.0,
@@ -137,7 +139,7 @@ def sample(
     for t in range(n_frames):
         img_matrix[t][0] = images_v0[t]
 
-    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
+    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
     save_video(
         os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
         img_matrix[0],
@@ -155,6 +157,7 @@ def sample(
         num_steps,
         verbose,
     )
+    model = initial_model_load(model)
 
     # Interleaved sampling for anchor frames
     t0, v0 = 0, 0
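Not part of the diff: simple_video_sample_4d.py exposes sample() through fire.Fire, so the new img_size argument is available as a --img_size command-line flag alongside the flags documented in the README. A hedged sketch of calling it programmatically, assuming the repository root is on PYTHONPATH and the scripts packages are importable (paths and values are illustrative):

    # Command line (from the updated README):
    #   python scripts/sampling/simple_video_sample_4d.py \
    #       --input_path assets/test_video1.mp4 --output_folder outputs/sv4d --img_size 512
    # Roughly equivalent direct call:
    from scripts.sampling.simple_video_sample_4d import sample

    sample(
        input_path="assets/test_video1.mp4",  # illustrative input video
        output_folder="outputs/sv4d",
        sv3d_version="sv3d_u",
        num_steps=20,
        img_size=512,  # new in this commit; lower than the 576 default to reduce VRAM use
    )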
sgm/modules/spacetime_attention.py (mode changed: normal file → executable file): 2 changed lines
@@ -593,4 +593,4 @@ class PostHocSpatialTransformerWithTimeMixingAndMotion(SpatialTransformer):
         if not self.use_linear:
             x = self.proj_out(x)
         out = x + x_in
-        return out
+        return out