Mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-20 14:54:21 +01:00)
The watermark encoder expects images in BGR channel order (matching cv2.imread). This fix reduces watermark artifacts.
@@ -83,7 +83,7 @@ class GetWatermarkMatch:
     def __call__(self, x: np.ndarray) -> np.ndarray:
         """
         Detects the number of matching bits the predefined watermark with one
-        or multiple images. Images should be in cv2 format, e.g. h x w x c.
+        or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
 
         Args:
             x: ([B], h w, c) in range [0, 255]
@@ -94,7 +94,6 @@ class GetWatermarkMatch:
         squeeze = len(x.shape) == 3
         if squeeze:
            x = x[None, ...]
-        x = np.flip(x, axis=-1)
 
         bs = x.shape[0]
         detected = np.empty((bs, self.num_bits), dtype=bool)
@@ -43,19 +43,19 @@ class WatermarkEmbedder:
         Returns:
             same as input but watermarked
         """
-        # watermarking libary expects input as cv2 format
+        # watermarking libary expects input as cv2 BGR format
         squeeze = len(image.shape) == 4
         if squeeze:
             image = image[None, ...]
         n = image.shape[0]
         image_np = rearrange(
             (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()
+        ).numpy()[:,:,:,::-1]
         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
         for k in range(image_np.shape[0]):
             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
         image = torch.from_numpy(
-            rearrange(image_np, "(n b) h w c -> n b c h w", n=n)
+            rearrange(image_np[:,:,:,::-1], "(n b) h w c -> n b c h w", n=n)
         ).to(image.device)
         image = torch.clamp(image / 255, min=0.0, max=1.0)
         if squeeze: