fix device check

This commit is contained in:
Stephan Auerhahn
2023-08-06 12:26:01 +00:00
parent ea5f232d5d
commit 0c2c5c66a2

View File

@@ -35,9 +35,9 @@ class WatermarkEmbedder:
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
:, :, :, ::-1
]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
@@ -185,9 +185,7 @@ def do_sample(
randn = torch.randn(shape).to(device)
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
with ModelOnDevice(model.denoiser, device):
with ModelOnDevice(model.model, device):
@@ -214,14 +212,10 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
@@ -231,9 +225,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
.to(device)
.repeat(*N, 1)
)
@@ -242,9 +234,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
)
elif key == "target_size_as_tuple":
@@ -265,9 +255,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
def get_input_image_tensor(image: Image.Image, device="cuda"):
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((width, height))
image_array = np.array(image.convert("RGB"))
image_array = image_array[None].transpose(0, 3, 1, 2)
@@ -353,10 +341,24 @@ def do_img2img(
class ModelOnDevice(object):
def __init__(self, model, device):
def __init__(
self, model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
):
self.model = model
self.device = device
self.original_device = model.device
self.device = torch.device(device)
if isinstance(model, torch.Tensor):
self.original_device = model.device
else:
param = next(model.parameters(), None)
if param is not None:
self.original_device = param.device
else:
buf = next(model.buffers(), None)
if buf is not None:
self.original_device = buf.device
else:
# If device could not be found, turn this into a no-op
self.original_device = self.device
def __enter__(self):
if self.device != self.original_device:
@@ -367,11 +369,3 @@ class ModelOnDevice(object):
self.model.to(self.original_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()
def load_model(model, device):
if model.device != device:
old_device = model.device
model.to(device)
return old_device
return False