diff --git a/README.md b/README.md index 7e8d3d5..9293884 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,9 @@ ## News +**July 4, 2023** +- A technical report on SDXL is now available [here](assets/sdxl_report.pdf). + **June 22, 2023** diff --git a/assets/sdxl_report.pdf b/assets/sdxl_report.pdf new file mode 100644 index 0000000..577ef32 Binary files /dev/null and b/assets/sdxl_report.pdf differ diff --git a/scripts/demo/detect.py b/scripts/demo/detect.py index 823ae8d..96e9f21 100644 --- a/scripts/demo/detect.py +++ b/scripts/demo/detect.py @@ -83,7 +83,7 @@ class GetWatermarkMatch: def __call__(self, x: np.ndarray) -> np.ndarray: """ Detects the number of matching bits the predefined watermark with one - or multiple images. Images should be in cv2 format, e.g. h x w x c. + or multiple images. Images should be in cv2 format, e.g. h x w x c BGR. Args: x: ([B], h w, c) in range [0, 255] @@ -94,7 +94,6 @@ class GetWatermarkMatch: squeeze = len(x.shape) == 3 if squeeze: x = x[None, ...] - x = np.flip(x, axis=-1) bs = x.shape[0] detected = np.empty((bs, self.num_bits), dtype=bool) diff --git a/scripts/demo/sampling.py b/scripts/demo/sampling.py index 7e953e9..98d0af3 100644 --- a/scripts/demo/sampling.py +++ b/scripts/demo/sampling.py @@ -311,8 +311,9 @@ if __name__ == "__main__": samples, samples_z = out else: samples = out + samples_z = None - if add_pipeline: + if add_pipeline and samples_z is not None: st.write("**Running Refinement Stage**") samples = apply_refiner( samples_z, diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index dfb0056..8f53b5d 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -43,19 +43,19 @@ class WatermarkEmbedder: Returns: same as input but watermarked """ - # watermarking libary expects input as cv2 format + # watermarking library expects input as cv2 BGR format squeeze = len(image.shape) == 4 if squeeze: image = image[None, ...] 
n = image.shape[0] image_np = rearrange( (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy() + ).numpy()[:, :, :, ::-1] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] for k in range(image_np.shape[0]): image_np[k] = self.encoder.encode(image_np[k], "dwtDct") image = torch.from_numpy( - rearrange(image_np, "(n b) h w c -> n b c h w", n=n) + rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n) ).to(image.device) image = torch.clamp(image / 255, min=0.0, max=1.0) if squeeze: