mirror of https://github.com/Stability-AI/generative-models.git
synced 2026-02-05 05:44:29 +01:00

Compare commits (3 commits): 787abc0be9, 0aee97d395, fd9d14e02f

README.md (12 lines changed)
@@ -5,6 +5,18 @@
## News

**Nov 4, 2025**
- We are releasing **[Stable Part Diffusion 4D (SP4D)](https://huggingface.co/stabilityai/sp4d)**, a video-to-4D diffusion model for multi-view part video synthesis and animatable 3D asset generation. For research purposes:
  - **SP4D** was trained to generate 48 RGB frames and part segmentation maps (4 video frames x 12 camera views) at 576x576 resolution, given a 4-frame input video of the same size, ideally consisting of white-background images of a moving object.
  - Based on our previous 4D model [SV4D 2.0](https://huggingface.co/stabilityai/sv4d2.0), **SP4D** can simultaneously generate multi-view RGB videos as well as the corresponding kinematic part segmentations that are consistent across time and camera views.
  - The generated part videos can then be used to create animation-ready 3D assets with part-aware rigging capabilities.
  - Please check our [project page](https://stablepartdiffusion4d.github.io/), [arxiv paper](https://arxiv.org/pdf/2509.10687) and [video summary](https://www.youtube.com/watch?v=FXEFeh8tf0k) for more details.

**QUICKSTART**:
- Set up the environment following the SV4D instructions and download [sp4d.safetensors](https://huggingface.co/stabilityai/sp4d) from HuggingFace into `checkpoints/`.
- Run `python scripts/sampling/simple_video_sample_sp4d.py --input_path assets/sv4d_videos/cows.gif --output_folder outputs` to generate multi-view part videos given the sample input.

**May 20, 2025**
- We are releasing **[Stable Video 4D 2.0 (SV4D 2.0)](https://huggingface.co/stabilityai/sv4d2.0)**, an enhanced video-to-4D diffusion model for high-fidelity novel-view video synthesis and 4D asset generation. For research purposes:
  - **SV4D 2.0** was trained to generate 48 frames (12 video frames x 4 camera views) at 576x576 resolution, given a 12-frame input video of the same size, ideally consisting of white-background images of a moving object.
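For orientation, the 48 outputs described in the news entry above can be read as a 4 x 12 (time x view) grid. The sketch below is illustrative only and not code from the repository:

```python
# Illustrative sketch (not repository code) of the SP4D output layout:
# 48 frames arranged as T=4 time steps x V=12 camera views at 576x576.
import numpy as np

T, V, H, W = 4, 12, 576, 576
frames = np.zeros((T * V, 3, H, W), dtype=np.float32)  # flat list of 48 RGB frames

# Frame t of view v lives at grid[t, v].
grid = frames.reshape(T, V, 3, H, W)
print(grid.shape)  # (4, 12, 3, 576, 576)
```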
@@ -724,6 +724,7 @@ def run_img2vid(
    cond_view=None,
    decoding_t=None,
    cond_mv=True,
    part_maps=False,
):
    options = version_dict["options"]
    H = version_dict["H"]
@@ -792,6 +793,7 @@ def run_img2vid(
        force_cond_zero_embeddings=options.get("force_cond_zero_embeddings", None),
        return_latents=False,
        decoding_t=decoding_t,
        part_maps=part_maps,
    )

    return samples
@@ -921,6 +923,7 @@ def do_sample(
    T=None,
    additional_batch_uc_fields=None,
    decoding_t=None,
    part_maps=False,
):
    force_uc_zero_embeddings = default(force_uc_zero_embeddings, [])
    batch2model_input = default(batch2model_input, [])
@@ -989,6 +992,9 @@ def do_sample(
                else:
                    additional_model_inputs[k] = batch[k]

            if part_maps:
                shape = (math.prod(num_samples), C * 2, H // F, W // F)
            else:
                shape = (math.prod(num_samples), C, H // F, W // F)
            randn = torch.randn(shape).to("cuda")
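The only behavioural change in these hunks is the new `part_maps` flag: when set, the initial noise latent gets twice the usual channel count, because the RGB latent and the part-segmentation latent are stacked along the channel axis and later split again by the dual decoder. A small shape sketch follows (values assumed from the SP4D defaults, not repository code):

```python
# Hedged shape sketch for the part_maps branch above (assumed values, not repo code).
import math
import torch

num_samples = [1, 48]        # (batch, frames), as assumed to be passed to do_sample
C, H, W, F = 4, 576, 576, 8  # latent channels, image size, VAE downscale factor

part_maps = True
if part_maps:
    # RGB latent and part-map latent stacked along channels -> 2 * C
    shape = (math.prod(num_samples), C * 2, H // F, W // F)
else:
    shape = (math.prod(num_samples), C, H // F, W // F)

randn = torch.randn(shape)
print(randn.shape)  # torch.Size([48, 8, 72, 72])
```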
scripts/sampling/configs/sp4d.yaml (new executable file, 210 lines)
@@ -0,0 +1,210 @@
N_TIME: 4
N_VIEW: 12
N_FRAMES: 48

model:
  target: sgm.models.diffusion.DiffusionEngine
  params:
    scale_factor: 0.18215
    en_and_decode_n_samples_a_time: 8
    disable_first_stage_autocast: True
    ckpt_path: checkpoints/sp4d.safetensors
    dual_concat: True

    denoiser_config:
      target: sgm.modules.diffusionmodules.denoiser.Denoiser
      params:
        scaling_config:
          target: sgm.modules.diffusionmodules.denoiser_scaling.VScalingWithEDMcNoise

    network_config:
      target: sgm.modules.diffusionmodules.video_model.DualSpatialUNetWithCrossComm
      params:
        unet_config:
          adm_in_channels: 1280
          attention_resolutions: [4, 2, 1]
          channel_mult: [1, 2, 4, 4]
          context_dim: 1024
          motion_context_dim: 4
          extra_ff_mix_layer: True
          in_channels: 8
          legacy: False
          model_channels: 320
          num_classes: sequential
          num_head_channels: 64
          num_res_blocks: 2
          out_channels: 4
          replicate_time_mix_bug: True
          spatial_transformer_attn_type: softmax-xformers
          time_block_merge_factor: 0.0
          time_block_merge_strategy: learned_with_images
          time_kernel_size: [3, 1, 1]
          time_mix_legacy: False
          transformer_depth: 1
          use_checkpoint: False
          use_linear_in_transformer: True
          use_spatial_context: True
          use_spatial_transformer: True
          separate_motion_merge_factor: True
          use_motion_attention: True
          use_3d_attention: True
          use_camera_emb: True

    conditioner_config:
      target: sgm.modules.GeneralConditioner
      params:
        emb_models:

          - input_key: cond_frames_without_noise
            target: sgm.modules.encoders.modules.FrozenOpenCLIPImagePredictionEmbedder
            is_trainable: False
            params:
              n_cond_frames: ${N_TIME}
              n_copies: 1
              open_clip_embedding_config:
                target: sgm.modules.encoders.modules.FrozenOpenCLIPImageEmbedder
                params:
                  freeze: True

          - input_key: cond_frames
            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
            is_trainable: False
            params:
              is_ae: True
              n_cond_frames: ${N_FRAMES}
              n_copies: 1
              encoder_config:
                target: sgm.models.autoencoder.AutoencoderKLModeOnly
                params:
                  ddconfig:
                    attn_resolutions: []
                    attn_type: vanilla-xformers
                    ch: 128
                    ch_mult: [1, 2, 4, 4]
                    double_z: True
                    dropout: 0.0
                    in_channels: 3
                    num_res_blocks: 2
                    out_ch: 3
                    resolution: 256
                    z_channels: 4
                  embed_dim: 4
                  lossconfig:
                    target: torch.nn.Identity
                  monitor: val/rec_loss
              sigma_cond_config:
                target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
                params:
                  outdim: 256
              sigma_sampler_config:
                target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

          - input_key: polar_rad
            is_trainable: False
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 512

          - input_key: azimuth_rad
            is_trainable: False
            target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
            params:
              outdim: 512

          - input_key: cond_view
            is_trainable: False
            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
            params:
              is_ae: True
              n_cond_frames: ${N_VIEW}
              n_copies: 1
              encoder_config:
                target: sgm.models.autoencoder.AutoencoderKLModeOnly
                params:
                  embed_dim: 4
                  monitor: val/rec_loss
                  ddconfig:
                    attn_resolutions: []
                    attn_type: vanilla-xformers
                    ch: 128
                    ch_mult: [1, 2, 4, 4]
                    double_z: True
                    dropout: 0.0
                    in_channels: 3
                    num_res_blocks: 2
                    out_ch: 3
                    resolution: 256
                    z_channels: 4
                  lossconfig:
                    target: torch.nn.Identity
              sigma_sampler_config:
                target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

          - input_key: cond_motion
            is_trainable: False
            target: sgm.modules.encoders.modules.VideoPredictionEmbedderWithEncoder
            params:
              is_ae: True
              n_cond_frames: ${N_TIME}
              n_copies: 1
              encoder_config:
                target: sgm.models.autoencoder.AutoencoderKLModeOnly
                params:
                  embed_dim: 4
                  monitor: val/rec_loss
                  ddconfig:
                    attn_resolutions: []
                    attn_type: vanilla-xformers
                    ch: 128
                    ch_mult: [1, 2, 4, 4]
                    double_z: True
                    dropout: 0.0
                    in_channels: 3
                    num_res_blocks: 2
                    out_ch: 3
                    resolution: 256
                    z_channels: 4
                  lossconfig:
                    target: torch.nn.Identity
              sigma_sampler_config:
                target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler

    first_stage_config:
      target: sgm.models.autoencoder.AutoencodingEngine
      params:
        loss_config:
          target: torch.nn.Identity
        regularizer_config:
          target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer
        encoder_config:
          target: torch.nn.Identity
        decoder_config:
          target: sgm.modules.diffusionmodules.model.DecoderDual
          params:
            attn_resolutions: []
            attn_type: vanilla-xformers
            ch: 128
            ch_mult: [1, 2, 4, 4]
            double_z: True
            dropout: 0.0
            in_channels: 3
            num_res_blocks: 2
            out_ch: 3
            resolution: 256
            z_channels: 4

    sampler_config:
      target: sgm.modules.diffusionmodules.sampling.EulerEDMSampler
      params:
        num_steps: 50
        discretization_config:
          target: sgm.modules.diffusionmodules.discretizer.EDMDiscretization
          params:
            sigma_max: 500.0
        guider_config:
          target: sgm.modules.diffusionmodules.guiders.SpatiotemporalPredictionGuider
          params:
            max_scale: 1.5
            min_scale: 1.5
            num_frames: ${N_FRAMES}
            num_views: ${N_VIEW}
            additional_cond_keys: [ cond_view, cond_motion ]
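For reference, sgm-style configs like the one above are usually loaded with OmegaConf and turned into a model via `instantiate_from_config`. The snippet below is a hedged sketch of that pattern; the sampling script's own `load_model` helper wraps the same idea and may differ in detail:

```python
# Hedged sketch of how an sgm-style YAML config such as sp4d.yaml is typically
# turned into a model; the repository's loading helper may differ in detail.
from omegaconf import OmegaConf
from sgm.util import instantiate_from_config

config = OmegaConf.load("scripts/sampling/configs/sp4d.yaml")
# ${N_TIME}/${N_VIEW}/${N_FRAMES} interpolations resolve against the top-level keys.
OmegaConf.resolve(config)

model = instantiate_from_config(config.model)  # builds DiffusionEngine with dual_concat=True
model = model.eval().cuda()
```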
scripts/sampling/simple_video_sample_sp4d.py (new executable file, 198 lines)
@@ -0,0 +1,198 @@
import os
import sys
from glob import glob
from typing import List, Optional

from tqdm import tqdm

sys.path.append(os.path.realpath(os.path.join(os.path.dirname(__file__), "../../")))
import numpy as np
import torch
from fire import Fire
from scripts.demo.sv4d_helpers import (
    load_model,
    preprocess_video,
    read_video,
    run_img2vid,
    save_video,
)
from sgm.modules.encoders.modules import VideoPredictionEmbedderWithEncoder

sp4d_configs = {
    "sp4d": {
        "T": 4,  # number of frames per sample
        "V": 12,  # number of views per sample
        "model_config": "scripts/sampling/configs/sp4d.yaml",
        "version_dict": {
            "T": 48,
            "options": {
                "discretization": 1,
                "cfg": 3.0,
                "min_cfg": 1.5,
                "num_views": 12,
                "sigma_min": 0.002,
                "sigma_max": 700.0,
                "rho": 7.0,
                "guider": 2,
                "force_uc_zero_embeddings": [
                    "cond_frames",
                    "cond_frames_without_noise",
                    "cond_view",
                    "cond_motion",
                ],
                "additional_guider_kwargs": {
                    "additional_cond_keys": ["cond_view", "cond_motion"]
                },
            },
        },
    },
}


def sample(
    input_path: str = "assets/sv4d_videos/camel.gif",  # Can either be image file or folder with image files
    model_path: Optional[str] = "checkpoints/sp4d.safetensors",
    output_folder: Optional[str] = "outputs",
    num_steps: Optional[int] = 50,
    img_size: int = 512,  # image resolution
    n_frames: int = 4,  # number of input and output video frames
    seed: int = 23,
    encoding_t: int = 8,  # Number of frames encoded at a time! This eats most VRAM. Reduce if necessary.
    decoding_t: int = 4,  # Number of frames decoded at a time! This eats most VRAM. Reduce if necessary.
    device: str = "cuda",
    elevations_deg: Optional[List[float]] = 0.0,
    azimuths_deg: Optional[List[float]] = None,
    image_frame_ratio: Optional[float] = 0.9,
    verbose: Optional[bool] = False,
    remove_bg: bool = False,
):
    """
    Simple script to generate multiple novel-view videos conditioned on a video `input_path` or multiple frames, one for each
    image file in folder `input_path`. If you run out of VRAM, try decreasing `decoding_t` and `encoding_t`.
    """
    # Set model config
    assert os.path.basename(model_path) in [
        "sp4d.safetensors",
    ]
    sp4d_model = os.path.splitext(os.path.basename(model_path))[0]
    config = sp4d_configs[sp4d_model]
    print(sp4d_model, config)
    T = config["T"]
    V = config["V"]
    model_config = config["model_config"]
    version_dict = config["version_dict"]
    F = 8  # vae factor to downsize image->latent
    C = 4
    H, W = img_size, img_size
    n_views = V + 1  # number of output video views (1 input view + 12 novel views)
    subsampled_views = np.arange(n_views)
    version_dict["H"] = H
    version_dict["W"] = W
    version_dict["C"] = C
    version_dict["f"] = F
    version_dict["options"]["num_steps"] = num_steps

    torch.manual_seed(seed)
    output_folder = os.path.join(output_folder, sp4d_model)
    os.makedirs(output_folder, exist_ok=True)

    # Read input video frames i.e. images at view 0
    print(f"Reading {input_path}")
    base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // (n_frames + 1)
    processed_input_path = preprocess_video(
        input_path,
        remove_bg=remove_bg,
        n_frames=n_frames,
        W=W,
        H=H,
        output_folder=output_folder,
        image_frame_ratio=image_frame_ratio,
        base_count=base_count,
    )
    images_v0 = read_video(processed_input_path, n_frames=n_frames, device=device)
    images_t0 = torch.zeros(n_views, 3, H, W).float().to(device)

    # Get camera viewpoints
    if isinstance(elevations_deg, float) or isinstance(elevations_deg, int):
        elevations_deg = [elevations_deg] * n_views
    assert (
        len(elevations_deg) == n_views
    ), f"Please provide 1 value, or a list of {n_views} values for elevations_deg! Given {len(elevations_deg)}"
    if azimuths_deg is None:
        azimuths_deg = np.linspace(0, 360, n_views + 1)[1:] % 360
    assert (
        len(azimuths_deg) == n_views
    ), f"Please provide a list of {n_views} values for azimuths_deg! Given {len(azimuths_deg)}"
    polars_rad = np.array([np.deg2rad(90 - e) for e in elevations_deg])
    azimuths_rad = np.array(
        [np.deg2rad((a - azimuths_deg[-1]) % 360) for a in azimuths_deg]
    )

    # Initialize image matrix
    img_matrix = [[None] * n_views for _ in range(n_frames)]
    for i, v in enumerate(subsampled_views):
        img_matrix[0][i] = images_t0[v].unsqueeze(0)
    for t in range(n_frames):
        img_matrix[t][0] = images_v0[t]

    # Load SP4D model
    model, _ = load_model(
        model_config,
        device,
        version_dict["T"],
        num_steps,
        verbose,
        model_path,
    )
    model.en_and_decode_n_samples_a_time = decoding_t
    for emb in model.conditioner.embedders:
        if isinstance(emb, VideoPredictionEmbedderWithEncoder):
            emb.en_and_decode_n_samples_a_time = encoding_t

    # Sampling novel-view videos
    v0 = 0
    view_indices = np.arange(V) + 1
    t0_list = range(0, n_frames - T + 1, T - 1)
    for t0 in tqdm(t0_list):
        if t0 + T > n_frames:
            t0 = n_frames - T
        frame_indices = t0 + np.arange(T)
        print(f"Sampling frames {frame_indices}")
        image = img_matrix[t0][v0]
        cond_motion = torch.cat([img_matrix[t][v0] for t in frame_indices], 0)
        cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
        polars = polars_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        azims = azimuths_rad[subsampled_views[1:]][None].repeat(T, 0).flatten()
        polars = (polars - polars_rad[v0] + torch.pi / 2) % (torch.pi * 2)
        azims = (azims - azimuths_rad[v0]) % (torch.pi * 2)
        samples = run_img2vid(
            version_dict,
            model,
            image,
            seed,
            polars,
            azims,
            cond_motion,
            cond_view,
            decoding_t,
            cond_mv=False,
            part_maps=True,
        )
        samples = samples.view(T, V, 3, H, -1)

        for i, t in enumerate(frame_indices):
            for j, v in enumerate(view_indices):
                img_matrix[t][v] = samples[i, j][None] * 2 - 1

    # Save output videos
    for t in frame_indices:
        vid_file = os.path.join(output_folder, f"{base_count:06d}_{t:03d}.mp4")
        print(f"Saving {vid_file}")
        save_video(
            vid_file,
            [img_matrix[t][v] for v in range(1, n_views) if img_matrix[t][v] is not None],
        )


if __name__ == "__main__":
    Fire(sample)
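Since `DecoderDual` concatenates the decoded RGB image and part map along the width axis, each sample in `samples.view(T, V, 3, H, -1)` above is twice as wide as the input resolution. A hedged post-processing sketch (hypothetical variable names, not repository code):

```python
# Hedged sketch (not repository code): splitting an SP4D sample into its RGB
# and part-map halves, which sit side by side along the width dimension.
import torch

T, V, H, W = 4, 12, 576, 576
samples = torch.zeros(T, V, 3, H, 2 * W)  # as after samples.view(T, V, 3, H, -1)

rgb_frames = samples[..., :W]   # left half: multi-view RGB video frames
part_maps = samples[..., W:]    # right half: kinematic part segmentation maps
print(rgb_frames.shape, part_maps.shape)  # both torch.Size([4, 12, 3, 576, 576])
```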
@@ -38,6 +38,7 @@ class DiffusionEngine(pl.LightningModule):
        no_cond_log: bool = False,
        compile_model: bool = False,
        en_and_decode_n_samples_a_time: Optional[int] = None,
        dual_concat: bool = False,
    ):
        super().__init__()
        self.log_keys = log_keys
@@ -47,7 +48,7 @@ class DiffusionEngine(pl.LightningModule):
        )
        model = instantiate_from_config(network_config)
        self.model = get_obj_from_str(default(network_wrapper, OPENAIUNETWRAPPER))(
            model, compile_model=compile_model, dual_concat=dual_concat
        )

        self.denoiser = instantiate_from_config(denoiser_config)
@@ -746,3 +746,170 @@ class Decoder(nn.Module):
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class DecoderDual(nn.Module):
    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        logpy.info(
            "Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        make_attn_cls = self._make_attn()
        make_resblock_cls = self._make_resblock()
        make_conv_cls = self._make_conv()

        # z to block_in (processes a single latent)
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = make_resblock_cls(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn_cls(block_in, attn_type=attn_type)
        self.mid.block_2 = make_resblock_cls(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    make_resblock_cls(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn_cls(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = make_conv_cls(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def _make_attn(self) -> Callable:
        return make_attn

    def _make_resblock(self) -> Callable:
        return ResnetBlock

    def _make_conv(self) -> Callable:
        return torch.nn.Conv2d

    def get_last_layer(self, **kwargs):
        return self.conv_out.weight

    def forward(self, z, **kwargs):
        """
        The input z should have shape (B, 2 * z_channels, H, W):
        - the first half of the channels is the first latent, the second half is the second latent
        - the two halves are decoded separately and the results are concatenated along the width dimension
        """
        # Sanity check: the input must carry 2 * z_channels channels
        assert (
            z.shape[1] == 2 * self.z_shape[1]
        ), f"Expected {2 * self.z_shape[1]} channels, got {z.shape[1]}"

        # Split the latent into two parts
        z1, z2 = torch.chunk(z, 2, dim=1)  # split along the channel dimension (C)

        # Decode each latent separately
        img1 = self.decode_single(z1, **kwargs)
        img2 = self.decode_single(z2, **kwargs)

        # Concatenate along the width dimension
        output = torch.cat([img1, img2], dim=-1)  # concatenated along width

        return output

    def decode_single(self, z, **kwargs):
        """Decode a single latent into one image."""
        self.last_z_shape = z.shape

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, None, **kwargs)
        h = self.mid.attn_1(h, **kwargs)
        h = self.mid.block_2(h, None, **kwargs)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, None, **kwargs)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h, **kwargs)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h, **kwargs)
        if self.tanh_out:
            h = torch.tanh(h)

        return h
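The forward pass above boils down to three steps: split the stacked latent along channels, decode each half with the same single-latent decoder, and concatenate the two images along width. A self-contained toy version of that flow (not the repository class; a stand-in decoder is used so it runs on its own):

```python
# Toy, self-contained illustration of the DecoderDual idea above (not the repo class):
# split a stacked latent along channels, decode each half, concatenate along width.
import torch
import torch.nn as nn


class ToyDualDecoder(nn.Module):
    def __init__(self, z_channels: int = 4, out_ch: int = 3, scale: int = 8):
        super().__init__()
        # Stand-in for the real decoder: upsample by `scale` and map to out_ch.
        self.decode_single = nn.Sequential(
            nn.Upsample(scale_factor=scale, mode="nearest"),
            nn.Conv2d(z_channels, out_ch, kernel_size=3, padding=1),
        )
        self.z_channels = z_channels

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        assert z.shape[1] == 2 * self.z_channels
        z1, z2 = torch.chunk(z, 2, dim=1)        # split along channels
        img1 = self.decode_single(z1)            # e.g. RGB frame
        img2 = self.decode_single(z2)            # e.g. part segmentation map
        return torch.cat([img1, img2], dim=-1)   # side by side along width


z = torch.randn(1, 8, 72, 72)                    # (B, 2 * z_channels, h, w)
print(ToyDualDecoder()(z).shape)                 # torch.Size([1, 3, 576, 1152])
```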
@@ -13,6 +13,7 @@ from ...modules.spacetime_attention import (
from ...util import default
from .util import AlphaBlender, get_alpha

import torch


class VideoResBlock(ResBlock):
    def __init__(
@@ -1252,3 +1253,157 @@ class SpatialUNetModelWithTime(nn.Module):
        )
        h = h.type(x.dtype)
        return self.out(h)


class CrossNetworkLayer(nn.Module):
    def __init__(self, feature_dim: int):
        super().__init__()
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
        )

    def forward(self, h1: torch.Tensor, h2: torch.Tensor):
        """
        h1, h2: (B, C, H, W)
        return: (out1, out2), (B, C, H, W)
        """
        fused_input = torch.cat([h1, h2], dim=1)  # (B, 2C, H, W)
        fused_output = self.fusion_conv(fused_input)  # (B, C, H, W)
        out1 = fused_output + h1
        out2 = fused_output + h2
        return out1, out2


class DualSpatialUNetWithCrossComm(nn.Module):
    def __init__(self, unet_config):
        super().__init__()
        self.num_classes = unet_config["num_classes"]
        self.model_channels = unet_config["model_channels"]

        self.net1 = SpatialUNetModelWithTime(**unet_config)
        self.net2 = SpatialUNetModelWithTime(**unet_config)

        self.input_cross_layers = nn.ModuleList()
        for block in self.net1.input_blocks:
            out_ch = self._get_block_out_channels(block)
            self.input_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))

        middle_out_ch = self._get_block_out_channels(self.net1.middle_block)
        self.middle_cross = CrossNetworkLayer(feature_dim=middle_out_ch)

        self.output_cross_layers = nn.ModuleList()
        for block in self.net1.output_blocks:
            out_ch = self._get_block_out_channels(block)
            self.output_cross_layers.append(CrossNetworkLayer(feature_dim=out_ch))

    def _get_block_out_channels(self, block: nn.Module) -> int:
        mod_list = list(block.children())
        for m in reversed(mod_list):
            if hasattr(m, "out_channels"):
                return m.out_channels

            if isinstance(
                m,
                (SpatialTransformer, PostHocSpatialTransformerWithTimeMixingAndMotion),
            ):
                return m.in_channels

            if isinstance(m, nn.Conv2d):
                return m.out_channels

        raise ValueError(f"Cannot determine out_channels from block: {block}")

    def forward(
        self,
        x: th.Tensor,
        timesteps: th.Tensor,
        context: Optional[th.Tensor] = None,
        y: Optional[th.Tensor] = None,
        cam: Optional[th.Tensor] = None,
        time_context: Optional[th.Tensor] = None,
        num_video_frames: Optional[int] = None,
        image_only_indicator: Optional[th.Tensor] = None,
        cond_view: Optional[th.Tensor] = None,
        cond_motion: Optional[th.Tensor] = None,
        time_step: Optional[int] = None,
    ):

        # ============ encoder ============
        h1, h2 = x[:, : x.shape[1] // 2], x[:, x.shape[1] // 2 :]

        encoder_feats1 = []
        encoder_feats2 = []

        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional -> no, relax this TODO"

        t_emb = timestep_embedding(
            timesteps, self.model_channels, repeat_only=False
        )  # 21 x 320

        emb = self.net1.time_embed(t_emb)  # 21 x 1280
        time = str(timesteps[0].data.cpu().numpy())

        if self.num_classes is not None:
            assert y.shape[0] == h1.shape[0]
            emb = emb + self.net1.label_emb(y)  # 21 x 1280

        filtered_args = {
            "emb": emb,
            "context": context,
            "cam": cam,
            "cond_view": cond_view,
            "cond_motion": cond_motion,
            "time_context": time_context,
            "num_video_frames": num_video_frames,
            "image_only_indicator": image_only_indicator,
            "time_step": time_step,
        }

        for i, (block1, block2) in enumerate(
            zip(self.net1.input_blocks, self.net2.input_blocks)
        ):
            h1 = block1(h1, name="encoder_{}_{}".format(time, i), **filtered_args)
            h2 = block2(h2, name="encoder_{}_{}".format(time, i), **filtered_args)

            # cross
            h1, h2 = self.input_cross_layers[i](h1, h2)

            encoder_feats1.append(h1)
            encoder_feats2.append(h2)

        # ============ middle block ============
        h1 = self.net1.middle_block(
            h1, name="middle_{}_0".format(time, i), **filtered_args
        )
        h2 = self.net2.middle_block(
            h2, name="middle_{}_0".format(time, i), **filtered_args
        )

        # cross
        h1, h2 = self.middle_cross(h1, h2)

        # ============ decoder ============
        for i, (block1, block2) in enumerate(
            zip(self.net1.output_blocks, self.net2.output_blocks)
        ):
            skip1 = encoder_feats1.pop()
            skip2 = encoder_feats2.pop()
            h1 = torch.cat([h1, skip1], dim=1)
            h2 = torch.cat([h2, skip2], dim=1)

            h1 = block1(h1, name="decoder_{}_{}".format(time, i), **filtered_args)
            h2 = block2(h2, name="decoder_{}_{}".format(time, i), **filtered_args)

            # cross
            h1, h2 = self.output_cross_layers[i](h1, h2)

        # ============ output ============
        out1 = self.net1.out(h1)  # shape: (B, out_channels, H, W)
        out2 = self.net2.out(h2)  # same shape
        out = torch.cat([out1, out2], dim=1)

        return out
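`CrossNetworkLayer` is the glue between the two UNet streams: it concatenates both feature maps, mixes them through a small 1x1-conv bottleneck, and adds the fused result back to each stream as a residual. A minimal standalone check of that pattern (a simplified mirror of the class above, not the repository code itself):

```python
# Minimal standalone check of the CrossNetworkLayer fusion pattern above:
# concatenate both streams, mix with 1x1 convs, add the result back to each.
import torch
import torch.nn as nn


class CrossFusion(nn.Module):  # simplified mirror of CrossNetworkLayer
    def __init__(self, feature_dim: int):
        super().__init__()
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(feature_dim * 2, feature_dim, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(feature_dim, feature_dim, kernel_size=1),
        )

    def forward(self, h1, h2):
        fused = self.fusion_conv(torch.cat([h1, h2], dim=1))
        return fused + h1, fused + h2  # residual update of both streams


h1 = torch.randn(2, 320, 32, 32)  # RGB-branch features
h2 = torch.randn(2, 320, 32, 32)  # part-branch features
out1, out2 = CrossFusion(320)(h1, h2)
print(out1.shape, out2.shape)      # both torch.Size([2, 320, 32, 32])
```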
@@ -6,7 +6,7 @@ OPENAIUNETWRAPPER = "sgm.modules.diffusionmodules.wrappers.OpenAIWrapper"


class IdentityWrapper(nn.Module):
    def __init__(self, diffusion_model, compile_model: bool = False, dual_concat: bool = False):
        super().__init__()
        compile = (
            torch.compile
@@ -15,6 +15,7 @@ class IdentityWrapper(nn.Module):
            else lambda x: x
        )
        self.diffusion_model = compile(diffusion_model)
        self.dual_concat = dual_concat

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)
@@ -24,6 +25,13 @@ class OpenAIWrapper(IdentityWrapper):
    def forward(
        self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
    ) -> torch.Tensor:
        if self.dual_concat:
            x_1 = x[:, : x.shape[1] // 2]
            x_2 = x[:, x.shape[1] // 2 :]
            x_1 = torch.cat((x_1, c.get("concat", torch.Tensor([]).type_as(x_1))), dim=1)
            x_2 = torch.cat((x_2, c.get("concat", torch.Tensor([]).type_as(x_2))), dim=1)
            x = torch.cat((x_1, x_2), dim=1)
        else:
            x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
        if "cond_view" in c:
            return self.diffusion_model(
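With `dual_concat`, each half of the noisy latent receives its own copy of the `concat` conditioning before the tensor is handed to the dual UNet, which splits it in half again internally. A hedged channel-arithmetic sketch using values from `sp4d.yaml` (4 latent channels per stream, `in_channels: 8` in the UNet config); the variable names here are illustrative:

```python
# Hedged sketch of the dual_concat channel arithmetic in OpenAIWrapper.forward above.
import torch

B, C, h, w = 1, 4, 72, 72
x = torch.randn(B, 2 * C, h, w)         # noisy RGB latent + part latent, stacked
concat_cond = torch.randn(B, C, h, w)   # c["concat"]: encoded conditioning frames

x_1, x_2 = x[:, :C], x[:, C:]
x_1 = torch.cat((x_1, concat_cond), dim=1)  # (B, 8, h, w) -> UNet stream 1
x_2 = torch.cat((x_2, concat_cond), dim=1)  # (B, 8, h, w) -> UNet stream 2
x = torch.cat((x_1, x_2), dim=1)            # (B, 16, h, w), split in half again inside the dual UNet
print(x.shape)
```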