mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-01-30 02:44:28 +01:00
add SV4D 2.0 (#440)
* add SV4D 2.0 * add SV4D 2.0 * Combined sv4dv2 and sv4dv2_8views sampling scripts --------- Co-authored-by: Vikram Voleti <vikram@ip-26-0-153-234.us-west-2.compute.internal>
This commit is contained in:
committed by
GitHub
parent
1659a1c09b
commit
c3147b86db
@@ -74,6 +74,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
x: th.Tensor,
|
||||
emb: th.Tensor,
|
||||
context: Optional[th.Tensor] = None,
|
||||
cam: Optional[th.Tensor] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
cond_view: Optional[th.Tensor] = None,
|
||||
cond_motion: Optional[th.Tensor] = None,
|
||||
@@ -86,7 +87,7 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
from ...modules.spacetime_attention import (
|
||||
BasicTransformerTimeMixBlock,
|
||||
PostHocSpatialTransformerWithTimeMixing,
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
||||
)
|
||||
|
||||
for layer in self:
|
||||
@@ -97,13 +98,30 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
(
|
||||
BasicTransformerTimeMixBlock,
|
||||
PostHocSpatialTransformerWithTimeMixing,
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion
|
||||
),
|
||||
):
|
||||
x = layer(
|
||||
x,
|
||||
context,
|
||||
# cam,
|
||||
emb,
|
||||
time_context,
|
||||
num_video_frames,
|
||||
image_only_indicator,
|
||||
cond_view,
|
||||
cond_motion,
|
||||
time_step,
|
||||
name,
|
||||
)
|
||||
elif isinstance(
|
||||
module,
|
||||
(
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
||||
),
|
||||
):
|
||||
x = layer(
|
||||
x,
|
||||
context,
|
||||
emb,
|
||||
time_context,
|
||||
num_video_frames,
|
||||
image_only_indicator,
|
||||
|
||||
@@ -8,10 +8,10 @@ from ...modules.video_attention import SpatialVideoTransformer
|
||||
from ...modules.spacetime_attention import (
|
||||
BasicTransformerTimeMixBlock,
|
||||
PostHocSpatialTransformerWithTimeMixing,
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion
|
||||
PostHocSpatialTransformerWithTimeMixingAndMotion,
|
||||
)
|
||||
from ...util import default
|
||||
from .util import AlphaBlender # , LegacyAlphaBlenderWithBug, get_alpha
|
||||
from .util import AlphaBlender, get_alpha
|
||||
|
||||
|
||||
class VideoResBlock(ResBlock):
|
||||
@@ -716,11 +716,11 @@ class PostHocResBlockWithTime(ResBlock):
|
||||
)
|
||||
|
||||
if self.time_mix_legacy:
|
||||
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator)
|
||||
alpha = self.get_alpha_fn(image_only_indicator=image_only_indicator*0.0)
|
||||
x = alpha.to(x.dtype) * x + (1.0 - alpha).to(x.dtype) * x_mix
|
||||
else:
|
||||
x = self.time_mixer(
|
||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator
|
||||
x_spatial=x_mix, x_temporal=x, image_only_indicator=image_only_indicator*0.0
|
||||
)
|
||||
x = rearrange(x, "b c t h w -> (b t) c h w")
|
||||
return x
|
||||
@@ -752,10 +752,14 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
context_dim: Optional[int] = None,
|
||||
time_downup: bool = False,
|
||||
time_context_dim: Optional[int] = None,
|
||||
view_context_dim: Optional[int] = None,
|
||||
motion_context_dim: Optional[int] = None,
|
||||
extra_ff_mix_layer: bool = False,
|
||||
use_spatial_context: bool = False,
|
||||
time_block_merge_strategy: str = "fixed",
|
||||
time_block_merge_factor: float = 0.5,
|
||||
view_block_merge_factor: float = 0.5,
|
||||
motion_block_merge_factor: float = 0.5,
|
||||
spatial_transformer_attn_type: str = "softmax",
|
||||
time_kernel_size: Union[int, List[int]] = 3,
|
||||
use_linear_in_transformer: bool = False,
|
||||
@@ -767,6 +771,9 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
max_ddpm_temb_period: int = 10000,
|
||||
replicate_time_mix_bug: bool = False,
|
||||
use_motion_attention: bool = False,
|
||||
use_camera_emb: bool = False,
|
||||
use_3d_attention: bool = False,
|
||||
separate_motion_merge_factor: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -886,11 +893,17 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
depth=depth,
|
||||
context_dim=context_dim,
|
||||
time_context_dim=time_context_dim,
|
||||
motion_context_dim=motion_context_dim,
|
||||
dropout=dropout,
|
||||
ff_in=extra_ff_mix_layer,
|
||||
use_spatial_context=use_spatial_context,
|
||||
use_camera_emb=use_camera_emb,
|
||||
use_3d_attention=use_3d_attention,
|
||||
separate_motion_merge_factor=separate_motion_merge_factor,
|
||||
adm_in_channels=adm_in_channels,
|
||||
merge_strategy=time_block_merge_strategy,
|
||||
merge_factor=time_block_merge_factor,
|
||||
merge_factor_motion=motion_block_merge_factor,
|
||||
checkpoint=use_checkpoint,
|
||||
use_linear=use_linear_in_transformer,
|
||||
attn_mode=spatial_transformer_attn_type,
|
||||
@@ -899,7 +912,7 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
time_mix_legacy=time_mix_legacy,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
)
|
||||
|
||||
|
||||
else:
|
||||
return PostHocSpatialTransformerWithTimeMixing(
|
||||
ch,
|
||||
@@ -1173,7 +1186,7 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
timesteps: th.Tensor,
|
||||
context: Optional[th.Tensor] = None,
|
||||
y: Optional[th.Tensor] = None,
|
||||
# cam: Optional[th.Tensor] = None,
|
||||
cam: Optional[th.Tensor] = None,
|
||||
time_context: Optional[th.Tensor] = None,
|
||||
num_video_frames: Optional[int] = None,
|
||||
image_only_indicator: Optional[th.Tensor] = None,
|
||||
@@ -1199,7 +1212,7 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
# cam=cam,
|
||||
cam=cam,
|
||||
image_only_indicator=image_only_indicator,
|
||||
cond_view=cond_view,
|
||||
cond_motion=cond_motion,
|
||||
@@ -1213,7 +1226,7 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
# cam=cam,
|
||||
cam=cam,
|
||||
image_only_indicator=image_only_indicator,
|
||||
cond_view=cond_view,
|
||||
cond_motion=cond_motion,
|
||||
@@ -1228,7 +1241,7 @@ class SpatialUNetModelWithTime(nn.Module):
|
||||
h,
|
||||
emb,
|
||||
context=context,
|
||||
# cam=cam,
|
||||
cam=cam,
|
||||
image_only_indicator=image_only_indicator,
|
||||
cond_view=cond_view,
|
||||
cond_motion=cond_motion,
|
||||
|
||||
Reference in New Issue
Block a user