Stable Video Diffusion

This commit is contained in:
Tim Dockhorn
2023-11-21 10:40:21 -08:00
parent 477d8b9a77
commit 059d8e9cd9
59 changed files with 5463 additions and 1691 deletions

View File

@@ -1,13 +1,13 @@
import os
from typing import Union, List, Optional
import math
import os
from typing import List, Optional, Union
import numpy as np
import torch
from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from PIL import Image
from torch import autocast
from sgm.util import append_dims
@@ -20,17 +20,16 @@ class WatermarkEmbedder:
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
def __call__(self, image: torch.Tensor) -> torch.Tensor:
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
image: ([N,] B, RGB, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking libary expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
@@ -39,6 +38,7 @@ class WatermarkEmbedder:
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
# watermarking libary expects input as cv2 BGR format
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(