mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-01-27 09:24:41 +01:00
Adds SV4D code
This commit is contained in:
@@ -17,6 +17,36 @@ import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
|
||||
|
||||
def get_alpha(
|
||||
merge_strategy: str,
|
||||
mix_factor: Optional[torch.Tensor],
|
||||
image_only_indicator: torch.Tensor,
|
||||
apply_sigmoid: bool = True,
|
||||
is_attn: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if merge_strategy == "fixed" or merge_strategy == "learned":
|
||||
alpha = mix_factor
|
||||
elif merge_strategy == "learned_with_images":
|
||||
alpha = torch.where(
|
||||
image_only_indicator.bool(),
|
||||
torch.ones(1, 1, device=image_only_indicator.device),
|
||||
rearrange(mix_factor, "... -> ... 1"),
|
||||
)
|
||||
if is_attn:
|
||||
alpha = rearrange(alpha, "b t -> (b t) 1 1")
|
||||
else:
|
||||
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
|
||||
elif merge_strategy == "fixed_with_images":
|
||||
alpha = image_only_indicator
|
||||
if is_attn:
|
||||
alpha = rearrange(alpha, "b t -> (b t) 1 1")
|
||||
else:
|
||||
alpha = rearrange(alpha, "b t -> b 1 t 1 1")
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return torch.sigmoid(alpha) if apply_sigmoid else alpha
|
||||
|
||||
|
||||
def make_beta_schedule(
|
||||
schedule,
|
||||
n_timestep,
|
||||
|
||||
Reference in New Issue
Block a user