add SV4D 2.0 (#440)

* add SV4D 2.0

* add SV4D 2.0

* Combined sv4dv2 and sv4dv2_8views sampling scripts

---------

Co-authored-by: Vikram Voleti <vikram@ip-26-0-153-234.us-west-2.compute.internal>
This commit is contained in:
chunhanyao-stable
2025-05-20 07:38:11 -07:00
committed by GitHub
parent 1659a1c09b
commit c3147b86db
44 changed files with 1000 additions and 116 deletions

View File

@@ -74,6 +74,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
x: th.Tensor,
emb: th.Tensor,
context: Optional[th.Tensor] = None,
cam: Optional[th.Tensor] = None,
image_only_indicator: Optional[th.Tensor] = None,
cond_view: Optional[th.Tensor] = None,
cond_motion: Optional[th.Tensor] = None,
@@ -86,7 +87,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
from ...modules.spacetime_attention import (
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
PostHocSpatialTransformerWithTimeMixingAndMotion,
)
for layer in self:
@@ -97,13 +98,30 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
(
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
),
):
x = layer(
x,
context,
# cam,
emb,
time_context,
num_video_frames,
image_only_indicator,
cond_view,
cond_motion,
time_step,
name,
)
elif isinstance(
module,
(
PostHocSpatialTransformerWithTimeMixingAndMotion,
),
):
x = layer(
x,
context,
emb,
time_context,
num_video_frames,
image_only_indicator,

View File

@@ -8,10 +8,10 @@ from ...modules.video_attention import SpatialVideoTransformer
from ...modules.spacetime_attention import (
BasicTransformerTimeMixBlock,
PostHocSpatialTransformerWithTimeMixing,
PostHocSpatialTransformerWithTimeMixingAndMotion
PostHocSpatialTransformerWithTimeMixingAndMotion,
)
from ...util import default
from .util import AlphaBlender # , LegacyAlphaBlenderWithBug, get_alpha
from .util import AlphaBlender, get_alpha
class VideoResBlock(ResBlock):
@@ -716,11 +716,11 @@ class PostHocResBlockWithTime(ResBlock):
)
if self.time_mix_legacy:
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator*0.0)
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
else:
x = self.time_mixer(
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator*0.0
)
x = rearrange(x, "b c t h w -> (b t) c h w")
return x
@@ -752,10 +752,14 @@ class SpatialUNetModelWithTime(nn.Module):
context_dim: Optional[int] = None,
time_downup: bool = False,
time_context_dim: Optional[int] = None,
view_context_dim: Optional[int] = None,
motion_context_dim: Optional[int] = None,
extra_ff_mix_layer: bool = False,
use_spatial_context: bool = False,
time_block_merge_strategy: str = "fixed",
time_block_merge_factor: float = 0.5,
view_block_merge_factor: float = 0.5,
motion_block_merge_factor: float = 0.5,
spatial_transformer_attn_type: str = "softmax",
time_kernel_size: Union[int, List[int]] = 3,
use_linear_in_transformer: bool = False,
@@ -767,6 +771,9 @@ class SpatialUNetModelWithTime(nn.Module):
max_ddpm_temb_period: int = 10000,
replicate_time_mix_bug: bool = False,
use_motion_attention: bool = False,
use_camera_emb: bool = False,
use_3d_attention: bool = False,
separate_motion_merge_factor: bool = False,
):
super().__init__()
@@ -886,11 +893,17 @@ class SpatialUNetModelWithTime(nn.Module):
depth=depth,
context_dim=context_dim,
time_context_dim=time_context_dim,
motion_context_dim=motion_context_dim,
dropout=dropout,
ff_in=extra_ff_mix_layer,
use_spatial_context=use_spatial_context,
use_camera_emb=use_camera_emb,
use_3d_attention=use_3d_attention,
separate_motion_merge_factor=separate_motion_merge_factor,
adm_in_channels=adm_in_channels,
merge_strategy=time_block_merge_strategy,
merge_factor=time_block_merge_factor,
merge_factor_motion=motion_block_merge_factor,
checkpoint=use_checkpoint,
use_linear=use_linear_in_transformer,
attn_mode=spatial_transformer_attn_type,
@@ -899,7 +912,7 @@ class SpatialUNetModelWithTime(nn.Module):
time_mix_legacy=time_mix_legacy,
max_time_embed_period=max_ddpm_temb_period,
)
else:
return PostHocSpatialTransformerWithTimeMixing(
ch,
@@ -1173,7 +1186,7 @@ class SpatialUNetModelWithTime(nn.Module):
timesteps: th.Tensor,
context: Optional[th.Tensor] = None,
y: Optional[th.Tensor] = None,
# cam: Optional[th.Tensor] = None,
cam: Optional[th.Tensor] = None,
time_context: Optional[th.Tensor] = None,
num_video_frames: Optional[int] = None,
image_only_indicator: Optional[th.Tensor] = None,
@@ -1199,7 +1212,7 @@ class SpatialUNetModelWithTime(nn.Module):
h,
emb,
context=context,
# cam=cam,
cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,
@@ -1213,7 +1226,7 @@ class SpatialUNetModelWithTime(nn.Module):
h,
emb,
context=context,
# cam=cam,
cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,
@@ -1228,7 +1241,7 @@ class SpatialUNetModelWithTime(nn.Module):
h,
emb,
context=context,
# cam=cam,
cam=cam,
image_only_indicator=image_only_indicator,
cond_view=cond_view,
cond_motion=cond_motion,