Adds SV4D code

This commit is contained in:
Vikram Voleti
2024-07-23 20:17:16 +00:00
parent fbdc58cab9
commit abe9ed3d40
16 changed files with 3174 additions and 23 deletions

View File

@@ -17,6 +17,36 @@ import torch.nn as nn
from einops import rearrange, repeat
def get_alpha(
merge_strategy: str,
mix_factor: Optional[torch.Tensor],
image_only_indicator: torch.Tensor,
apply_sigmoid: bool = True,
is_attn: bool = False,
) -> torch.Tensor:
if merge_strategy == "fixed" or merge_strategy == "learned":
alpha = mix_factor
elif merge_strategy == "learned_with_images":
alpha = torch.where(
image_only_indicator.bool(),
torch.ones(1, 1, device=image_only_indicator.device),
rearrange(mix_factor, "... -> ... 1"),
)
if is_attn:
alpha = rearrange(alpha, "b t -> (b t) 1 1")
else:
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
elif merge_strategy == "fixed_with_images":
alpha = image_only_indicator
if is_attn:
alpha = rearrange(alpha, "b t -> (b t) 1 1")
else:
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
else:
raise NotImplementedError
return torch.sigmoid(alpha) if apply_sigmoid else alpha
def make_beta_schedule(
schedule,
n_timestep,