mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-30 11:44:24 +01:00
SV3D inference code (#300)
* Makes init changes for SV3D * Small fixes : cond_aug * Fixes SV3D checkpoint, fixes rembg * Black formatting * Adds streamlit demo, fixes simple sample script * Removes SV3D video_decoder, keeps SV3D image_decoder * Updates README * Minor updates * Remove GSO script --------- Co-authored-by: Vikram Voleti <vikram@ip-26-0-153-234.us-west-2.compute.internal>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from einops import rearrange, repeat
|
||||
@@ -97,3 +97,35 @@ class LinearPredictionGuider(Guider):
|
||||
assert c[k] == uc[k]
|
||||
c_out[k] = c[k]
|
||||
return torch.cat([x] * 2), torch.cat([s] * 2), c_out
|
||||
|
||||
|
||||
class TrianglePredictionGuider(LinearPredictionGuider):
|
||||
def __init__(
|
||||
self,
|
||||
max_scale: float,
|
||||
num_frames: int,
|
||||
min_scale: float = 1.0,
|
||||
period: float | List[float] = 1.0,
|
||||
period_fusing: Literal["mean", "multiply", "max"] = "max",
|
||||
additional_cond_keys: Optional[Union[List[str], str]] = None,
|
||||
):
|
||||
super().__init__(max_scale, num_frames, min_scale, additional_cond_keys)
|
||||
values = torch.linspace(0, 1, num_frames)
|
||||
# Constructs a triangle wave
|
||||
if isinstance(period, float):
|
||||
period = [period]
|
||||
|
||||
scales = []
|
||||
for p in period:
|
||||
scales.append(self.triangle_wave(values, p))
|
||||
|
||||
if period_fusing == "mean":
|
||||
scale = sum(scales) / len(period)
|
||||
elif period_fusing == "multiply":
|
||||
scale = torch.prod(torch.stack(scales), dim=0)
|
||||
elif period_fusing == "max":
|
||||
scale = torch.max(torch.stack(scales), dim=0).values
|
||||
self.scale = (scale * (max_scale - min_scale) + min_scale).unsqueeze(0)
|
||||
|
||||
def triangle_wave(self, values: torch.Tensor, period) -> torch.Tensor:
|
||||
return 2 * (values / period - torch.floor(values / period + 0.5)).abs()
|
||||
|
||||
Reference in New Issue
Block a user