Stable Video Diffusion

This commit is contained in:
Tim Dockhorn
2023-11-21 10:40:21 -08:00
parent 477d8b9a77
commit 059d8e9cd9
59 changed files with 5463 additions and 1691 deletions

View File

@@ -37,10 +37,13 @@ def clip_process_images(images: torch.Tensor) -> torch.Tensor:
class DeepFloydDataFiltering(object):
def __init__(self, verbose: bool = False):
def __init__(
self, verbose: bool = False, device: torch.device = torch.device("cpu")
):
super().__init__()
self.verbose = verbose
self.clip_model, _ = clip.load("ViT-L/14", device="cpu")
self._device = None
self.clip_model, _ = clip.load("ViT-L/14", device=device)
self.clip_model.eval()
self.cpu_w_weights, self.cpu_w_biases = load_model_weights(
@@ -54,7 +57,9 @@ class DeepFloydDataFiltering(object):
@torch.inference_mode()
def __call__(self, images: torch.Tensor) -> torch.Tensor:
imgs = clip_process_images(images)
image_features = self.clip_model.encode_image(imgs.to("cpu"))
if self._device is None:
self._device = next(p for p in self.clip_model.parameters()).device
image_features = self.clip_model.encode_image(imgs.to(self._device))
image_features = image_features.detach().cpu().numpy().astype(np.float16)
p_pred = predict_proba(image_features, self.cpu_p_weights, self.cpu_p_biases)
w_pred = predict_proba(image_features, self.cpu_w_weights, self.cpu_w_biases)