diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 30dc082..6f06218 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -35,9 +35,9 @@ class WatermarkEmbedder: if squeeze: image = image[None, ...] n = image.shape[0] - image_np = rearrange( - (255 * image).detach().cpu(), "n b c h w -> (n b) h w c" - ).numpy()[:, :, :, ::-1] + image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[ + :, :, :, ::-1 + ] # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255] for k in range(image_np.shape[0]): image_np[k] = self.encoder.encode(image_np[k], "dwtDct") @@ -185,9 +185,7 @@ def do_sample( randn = torch.randn(shape).to(device) def denoiser(input, sigma, c): - return model.denoiser( - model.model, input, sigma, c, **additional_model_inputs - ) + return model.denoiser(model.model, input, sigma, c, **additional_model_inputs) with ModelOnDevice(model.denoiser, device): with ModelOnDevice(model.model, device): @@ -214,14 +212,10 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): for key in keys: if key == "txt": batch["txt"] = ( - np.repeat([value_dict["prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() + np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist() ) batch_uc["txt"] = ( - np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)) - .reshape(N) - .tolist() + np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist() ) elif key == "original_size_as_tuple": batch["original_size_as_tuple"] = ( @@ -231,9 +225,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): ) elif key == "crop_coords_top_left": batch["crop_coords_top_left"] = ( - torch.tensor( - [value_dict["crop_coords_top"], value_dict["crop_coords_left"]] - ) + torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]]) .to(device) .repeat(*N, 1) ) @@ -242,9 +234,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1) ) batch_uc["aesthetic_score"] = ( - torch.tensor([value_dict["negative_aesthetic_score"]]) - .to(device) - .repeat(*N, 1) + torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1) ) elif key == "target_size_as_tuple": @@ -265,9 +255,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"): def get_input_image_tensor(image: Image.Image, device="cuda"): w, h = image.size print(f"loaded input image of size ({w}, {h})") - width, height = map( - lambda x: x - x % 64, (w, h) - ) # resize to integer multiple of 64 + width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64 image = image.resize((width, height)) image_array = np.array(image.convert("RGB")) image_array = image_array[None].transpose(0, 3, 1, 2) @@ -353,10 +341,24 @@ def do_img2img( class ModelOnDevice(object): - def __init__(self, model, device): + def __init__( + self, model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str] + ): self.model = model - self.device = device - self.original_device = model.device + self.device = torch.device(device) + if isinstance(model, torch.Tensor): + self.original_device = model.device + else: + param = next(model.parameters(), None) + if param is not None: + self.original_device = param.device + else: + buf = next(model.buffers(), None) + if buf is not None: + self.original_device = buf.device + else: + # If device could not be found, turn this into a no-op + self.original_device = self.device def __enter__(self): if self.device != self.original_device: @@ -367,11 +369,3 @@ class ModelOnDevice(object): self.model.to(self.original_device) if torch.cuda.is_available(): torch.cuda.empty_cache() - - -def load_model(model, device): - if model.device != device: - old_device = model.device - model.to(device) - return old_device - return False