Mirror of https://github.com/Stability-AI/generative-models.git (synced 2026-02-02 04:14:27 +01:00)
fix device check
@@ -35,9 +35,9 @@ class WatermarkEmbedder:
         if squeeze:
             image = image[None, ...]
         n = image.shape[0]
-        image_np = rearrange(
-            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
-        ).numpy()[:, :, :, ::-1]
+        image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
+            :, :, :, ::-1
+        ]
         # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
         for k in range(image_np.shape[0]):
             image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
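The hunk above is a pure reflow of the tensor-to-numpy conversion; behaviour is unchanged. For reference, the conversion it performs looks roughly like the following standalone sketch (the helper name and the plain (b, c, h, w) layout are assumptions for illustration, not code from this commit):

import torch
from einops import rearrange

def to_bgr_uint8(image: torch.Tensor):
    # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) in [0, 255]
    image_np = rearrange((255 * image).detach().cpu(), "b c h w -> b h w c").numpy()
    # reverse the channel axis (RGB -> BGR) for the cv2-based watermark encoder
    return image_np[:, :, :, ::-1].astype("uint8")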
@@ -185,9 +185,7 @@ def do_sample(
                 randn = torch.randn(shape).to(device)
 
                 def denoiser(input, sigma, c):
-                    return model.denoiser(
-                        model.model, input, sigma, c, **additional_model_inputs
-                    )
+                    return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
 
                 with ModelOnDevice(model.denoiser, device):
                     with ModelOnDevice(model.model, device):
@@ -214,14 +212,10 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
     for key in keys:
         if key == "txt":
             batch["txt"] = (
-                np.repeat([value_dict["prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
+                np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
             )
             batch_uc["txt"] = (
-                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
-                .reshape(N)
-                .tolist()
+                np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
             )
         elif key == "original_size_as_tuple":
             batch["original_size_as_tuple"] = (
@@ -231,9 +225,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
             )
         elif key == "crop_coords_top_left":
             batch["crop_coords_top_left"] = (
-                torch.tensor(
-                    [value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
-                )
+                torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
                 .to(device)
                 .repeat(*N, 1)
             )
@@ -242,9 +234,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
                 torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
             )
             batch_uc["aesthetic_score"] = (
-                torch.tensor([value_dict["negative_aesthetic_score"]])
-                .to(device)
-                .repeat(*N, 1)
+                torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
             )
 
         elif key == "target_size_as_tuple":
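The get_batch hunks above are likewise line-length reflows of the chained repeat/reshape/tolist calls. What that chain does, shown on a made-up batch shape (N = [2, 3] and the prompt are illustrative values only):

import math
import numpy as np

N = [2, 3]  # hypothetical batch shape
prompt = "a photo of a cat"
txt = np.repeat([prompt], repeats=math.prod(N)).reshape(N).tolist()
# txt is a 2x3 nested list containing six copies of the prompt
assert len(txt) == 2 and len(txt[0]) == 3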
@@ -265,9 +255,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
 def get_input_image_tensor(image: Image.Image, device="cuda"):
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(
-        lambda x: x - x % 64, (w, h)
-    )  # resize to integer multiple of 64
+    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
     image = image.resize((width, height))
     image_array = np.array(image.convert("RGB"))
     image_array = image_array[None].transpose(0, 3, 1, 2)
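Again a reflow only. The x - x % 64 rounding it preserves snaps the input resolution down to the nearest multiple of 64, presumably because the downstream model expects dimensions divisible by its downsampling factor. A quick sketch with made-up sizes:

w, h = 513, 769                                    # hypothetical input size
width, height = map(lambda x: x - x % 64, (w, h))  # snap down to multiples of 64
assert (width, height) == (512, 768)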
@@ -353,10 +341,24 @@ def do_img2img(
 
 
 class ModelOnDevice(object):
-    def __init__(self, model, device):
+    def __init__(
+        self, model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
+    ):
         self.model = model
-        self.device = device
-        self.original_device = model.device
+        self.device = torch.device(device)
+        if isinstance(model, torch.Tensor):
+            self.original_device = model.device
+        else:
+            param = next(model.parameters(), None)
+            if param is not None:
+                self.original_device = param.device
+            else:
+                buf = next(model.buffers(), None)
+                if buf is not None:
+                    self.original_device = buf.device
+                else:
+                    # If device could not be found, turn this into a no-op
+                    self.original_device = self.device
 
     def __enter__(self):
         if self.device != self.original_device:
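This hunk is the device-check fix itself: torch.nn.Module does not expose a .device attribute, so the old self.original_device = model.device only worked for tensors (or objects that happen to define such a property). The new constructor inspects the module's parameters and buffers instead, and falls back to a no-op when neither exists. The same logic as a standalone helper, purely for reference (infer_device is a hypothetical name, not part of the commit):

import torch

def infer_device(model, fallback: torch.device) -> torch.device:
    if isinstance(model, torch.Tensor):
        return model.device          # tensors carry their device directly
    param = next(model.parameters(), None)
    if param is not None:
        return param.device          # modules: read it off the first parameter
    buf = next(model.buffers(), None)
    if buf is not None:
        return buf.device            # parameterless modules may still have buffers
    return fallback                  # nothing to inspect: treat the move as a no-op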
@@ -367,11 +369,3 @@ class ModelOnDevice(object):
             self.model.to(self.original_device)
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
-
-
-def load_model(model, device):
-    if model.device != device:
-        old_device = model.device
-        model.to(device)
-        return old_device
-    return False
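With the free-standing load_model helper removed, callers rely on the ModelOnDevice context manager instead, roughly as in the do_sample hunk above (a usage sketch, not additional code from the commit):

# move the denoiser and diffusion model to `device` only for the duration of
# sampling; __exit__ restores each one to its original device afterwards
with ModelOnDevice(model.denoiser, device):
    with ModelOnDevice(model.model, device):
        samples = sampler(denoiser, randn, cond=c, uc=uc)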