diff --git a/scripts/demo/detect.py b/scripts/demo/detect.py index 823ae8d..96e9f21 100644 --- a/scripts/demo/detect.py +++ b/scripts/demo/detect.py @@ -83,7 +83,7 @@ class GetWatermarkMatch: def __call__(self, x: np.ndarray) -> np.ndarray: """ Detects the number of matching bits the predefined watermark with one - or multiple images. Images should be in cv2 format, e.g. h x w x c. + or multiple images. Images should be in cv2 format, e.g. h x w x c BGR. Args: x: ([B], h w, c) in range [0, 255] @@ -94,7 +94,6 @@ class GetWatermarkMatch: squeeze = len(x.shape) == 3 if squeeze: x = x[None, ...] - x = np.flip(x, axis=-1) bs = x.shape[0] detected = np.empty((bs, self.num_bits), dtype=bool) diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index dfb0056..f459997 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -43,19 +43,19 @@ class WatermarkEmbedder: Returns: same as input but watermarked """ - # watermarking libary expects input as cv2 format + # watermarking library expects input as cv2 BGR format squeeze = len(image.shape) == 4 if squeeze: image = image[None, ...] n = image.shape[0] image_np = rearrange( (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy() + ).numpy()[:,:,:,::-1] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] for k in range(image_np.shape[0]): image_np[k] = self.encoder.encode(image_np[k], "dwtDct") image = torch.from_numpy( - rearrange(image_np[:,:,:,::-1], "(n b) h w c -> n b c h w", n=n) + ).to(image.device) image = torch.clamp(image / 255, min=0.0, max=1.0) if squeeze: