Watermark encoder expects images in BGR channel order (matching cv2 imread). This fix reduces the watermark artifacts.

This commit is contained in:
pharmapsychotic
2023-07-05 12:05:14 -05:00
parent ae18ba3e87
commit 5df4d9893c
2 changed files with 4 additions and 5 deletions

View File

@@ -83,7 +83,7 @@ class GetWatermarkMatch:
def __call__(self, x: np.ndarray) -> np.ndarray: def __call__(self, x: np.ndarray) -> np.ndarray:
""" """
Detects the number of matching bits the predefined watermark with one Detects the number of matching bits the predefined watermark with one
or multiple images. Images should be in cv2 format, e.g. h x w x c. or multiple images. Images should be in cv2 format, e.g. h x w x c BGR.
Args: Args:
x: ([B], h w, c) in range [0, 255] x: ([B], h w, c) in range [0, 255]
@@ -94,7 +94,6 @@ class GetWatermarkMatch:
squeeze = len(x.shape) == 3 squeeze = len(x.shape) == 3
if squeeze: if squeeze:
x = x[None, ...] x = x[None, ...]
x = np.flip(x, axis=-1)
bs = x.shape[0] bs = x.shape[0]
detected = np.empty((bs, self.num_bits), dtype=bool) detected = np.empty((bs, self.num_bits), dtype=bool)

View File

@@ -43,19 +43,19 @@ class WatermarkEmbedder:
Returns: Returns:
same as input but watermarked same as input but watermarked
""" """
# watermarking libary expects input as cv2 format # watermarking libary expects input as cv2 BGR format
squeeze = len(image.shape) == 4 squeeze = len(image.shape) == 4
if squeeze: if squeeze:
image = image[None, ...] image = image[None, ...]
n = image.shape[0] n = image.shape[0]
image_np = rearrange( image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c" (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy() ).numpy()[:,:,:,::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]): for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct") image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy( image = torch.from_numpy(
rearrange(image_np, "(n b) h w c -> n b c h w", n=n) rearrange(image_np[:,:,:,::-1], "(n b) h w c -> n b c h w", n=n)
).to(image.device) ).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0) image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze: if squeeze: