diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index f459997..8f53b5d 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -50,12 +50,12 @@ class WatermarkEmbedder: n = image.shape[0] image_np = rearrange( (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy()[:,:,:,::-1] + ).numpy()[:, :, :, ::-1] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] for k in range(image_np.shape[0]): image_np[k] = self.encoder.encode(image_np[k], "dwtDct") image = torch.from_numpy( - rearrange(image_np[:,:,:,::-1], "(n b) h w c -> n b c h w", n=n) + rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) ).to(image.device) image = torch.clamp(image / 255, min=0.0, max=1.0) if squeeze: