diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index 6f06218..b576669 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
@@ -35,9 +35,9 @@ class WatermarkEmbedder:
         if squeeze:
             image = image[None, ...]
         n = image.shape[0]
-        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
-            :, :, :, ::-1
-        ]
+        image_np = rearrange(
+            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
+        ).numpy()[:, :, :, ::-1]
         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
         for k in range(image_np.shape[0]):
             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
@@ -185,7 +185,9 @@ def do_sample(
         randn = torch.randn(shape).to(device)

         def denoiser(input, sigma, c):
-            return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
+            return model.denoiser(
+                model.model, input, sigma, c, **additional_model_inputs
+            )

         with ModelOnDevice(model.denoiser, device):
             with ModelOnDevice(model.model, device):
@@ -212,10 +214,14 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
     for key in keys:
         if key == "txt":
             batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
+                .reshape(N)
+                .tolist()
             )
             batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
+                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
+                .reshape(N)
+                .tolist()
             )
         elif key == "original_size_as_tuple":
             batch["original_size_as_tuple"] = (
@@ -225,7 +231,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
             )
         elif key == "crop_coords_top_left":
             batch["crop_coords_top_left"] = (
-                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
+                torch.tensor(
+                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
+                )
                 .to(device)
                 .repeat(*N, 1)
             )
@@ -234,7 +242,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
                 torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
             )
             batch_uc["aesthetic_score"] = (
-                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
+                torch.tensor([value_dict["negative_aesthetic_score"]])
+                .to(device)
+                .repeat(*N, 1)
             )

         elif key == "target_size_as_tuple":
@@ -255,7 +265,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
 def get_input_image_tensor(image: Image.Image, device="cuda"):
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+    width, height = map(
+        lambda x: x - x % 64, (w, h)
+    )  # resize to integer multiple of 64
     image = image.resize((width, height))
     image_array = np.array(image.convert("RGB"))
     image_array = image_array[None].transpose(0, 3, 1, 2)
@@ -342,7 +354,9 @@ def do_img2img(

 class ModelOnDevice(object):
     def __init__(
-        self, model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
+        self,
+        model: Union[torch.nn.Module, torch.Tensor],
+        device: Union[torch.device, str],
     ):
         self.model = model
         self.device = torch.device(device)
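
A note for reviewers on the get_batch hunks: since N is a batch shape (a List
or ListConfig, per the signature), math.prod(N) gives the total sample count
for the repeated prompt strings, and .repeat(*N, 1) tiles a 1-D conditioning
vector out to shape (*N, vector_length). A quick standalone check of that
behavior; the batch shape and prompt here are illustrative, not from the diff:

    import math

    import numpy as np
    import torch

    N = [2, 3]  # illustrative batch shape; real callers pass their own

    # "txt" branch: one prompt string repeated out to the batch shape
    # as nested Python lists.
    txt = np.repeat(["a photo of a cat"], repeats=math.prod(N)).reshape(N).tolist()
    print(len(txt), len(txt[0]))  # 2 3

    # "crop_coords_top_left" branch: a length-2 vector tiled to (*N, 2).
    coords = torch.tensor([0, 0]).repeat(*N, 1)
    print(coords.shape)  # torch.Size([2, 3, 2])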
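
The final hunk only rewraps the ModelOnDevice constructor; __enter__ and
__exit__ fall outside the diff context. For context, the nested
`with ModelOnDevice(...)` blocks in do_sample imply a context manager that
keeps a module (or tensor) on the target device for the duration of the
block. A minimal sketch of that protocol, assuming (this is not shown in the
diff) that the helper restores the original placement on exit:

    import torch
    from typing import Union


    class ModelOnDevice(object):
        # Sketch only: the real __enter__/__exit__ bodies are not in the diff.
        def __init__(
            self,
            model: Union[torch.nn.Module, torch.Tensor],
            device: Union[torch.device, str],
        ):
            self.model = model
            self.device = torch.device(device)

        def __enter__(self):
            # Remember the current placement, then move to the target device.
            # (Assumes a module with at least one parameter.)
            if isinstance(self.model, torch.nn.Module):
                self.original_device = next(self.model.parameters()).device
            else:
                self.original_device = self.model.device
            # nn.Module.to() moves in place; Tensor.to() returns a copy,
            # so rebind in both cases for uniformity.
            self.model = self.model.to(self.device)
            return self.model

        def __exit__(self, exc_type, exc_value, traceback):
            # Restore the original placement on the way out.
            self.model = self.model.to(self.original_device)

Usage then matches the do_sample hunk: `with ModelOnDevice(model.model,
device):` keeps model.model on the sampling device only while it is needed.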