mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 14:24:21 +01:00
Stable Video Diffusion
This commit is contained in:
@@ -37,10 +37,13 @@ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
class DeepFloydDataFiltering(object):
|
||||
def __init__(self, verbose: bool = False):
|
||||
def __init__(
|
||||
self, verbose: bool = False, device: torch.device = torch.device("cpu")
|
||||
):
|
||||
super().__init__()
|
||||
self.verbose = verbose
|
||||
self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
|
||||
self._device = None
|
||||
self.clip_model, _ = clip.load("ViT-L/14", device=device)
|
||||
self.clip_model.eval()
|
||||
|
||||
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
|
||||
@@ -54,7 +57,9 @@ class DeepFloydDataFiltering(object):
|
||||
@torch.inference_mode()
|
||||
def __call__(self, images: torch.Tensor) -> torch.Tensor:
|
||||
imgs = clip_process_images(images)
|
||||
image_features = self.clip_model.encode_image(imgs.to("cpu"))
|
||||
if self._device is None:
|
||||
self._device = next(p for p in self.clip_model.parameters()).device
|
||||
image_features = self.clip_model.encode_image(imgs.to(self._device))
|
||||
image_features = image_features.detach().cpu().numpy().astype(np.float16)
|
||||
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
|
||||
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)
|
||||
|
||||
Reference in New Issue
Block a user