mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
format
This commit is contained in:
@@ -35,9 +35,9 @@ class WatermarkEmbedder:
|
||||
if squeeze:
|
||||
image = image[None, ...]
|
||||
n = image.shape[0]
|
||||
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
|
||||
:, :, :, ::-1
|
||||
]
|
||||
image_np = rearrange(
|
||||
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||
).numpy()[:, :, :, ::-1]
|
||||
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||
for k in range(image_np.shape[0]):
|
||||
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||
@@ -185,7 +185,9 @@ def do_sample(
|
||||
randn = torch.randn(shape).to(device)
|
||||
|
||||
def denoiser(input, sigma, c):
|
||||
return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
|
||||
return model.denoiser(
|
||||
model.model, input, sigma, c, **additional_model_inputs
|
||||
)
|
||||
|
||||
with ModelOnDevice(model.denoiser, device):
|
||||
with ModelOnDevice(model.model, device):
|
||||
@@ -212,10 +214,14 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
batch["txt"] = (
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
batch_uc["txt"] = (
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
elif key == "original_size_as_tuple":
|
||||
batch["original_size_as_tuple"] = (
|
||||
@@ -225,7 +231,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
)
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
|
||||
torch.tensor(
|
||||
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||
)
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
@@ -234,7 +242,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
batch_uc["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
|
||||
elif key == "target_size_as_tuple":
|
||||
@@ -255,7 +265,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
width, height = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
image = image.resize((width, height))
|
||||
image_array = np.array(image.convert("RGB"))
|
||||
image_array = image_array[None].transpose(0, 3, 1, 2)
|
||||
@@ -342,7 +354,9 @@ def do_img2img(
|
||||
|
||||
class ModelOnDevice(object):
|
||||
def __init__(
|
||||
self, model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
|
||||
self,
|
||||
model: Union[torch.nn.Module, torch.Tensor],
|
||||
device: Union[torch.device, str],
|
||||
):
|
||||
self.model = model
|
||||
self.device = torch.device(device)
|
||||
|
||||
Reference in New Issue
Block a user