align with streamlit helpers and deduplicate

Stephan Auerhahn
2023-08-06 11:20:22 +00:00
parent 77d0e27747
commit b216934b7e
4 changed files with 140 additions and 468 deletions

View File: scripts/demo/sampling.py

@@ -1,5 +1,11 @@
from pytorch_lightning import seed_everything
from sgm.inference.helpers import (
do_img2img,
do_sample,
get_unique_embedder_keys_from_conditioner,
perform_save_locally,
)
from scripts.demo.streamlit_helpers import *
SAVE_PATH = "outputs/demo/txt2img/"
@@ -99,9 +105,7 @@ def load_img(display=True, key=None, device="cuda"):
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
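
The resize above is a dimension snap rather than a rescale of content: x - x % 64 floors each side to a multiple of 64 so the latent grid divides evenly through the UNet's downsampling stages. A self-contained sketch of the arithmetic (the function name is illustrative):

def snap_to_multiple_of_64(w, h):
    # floor each dimension to the nearest multiple of 64
    return w - w % 64, h - h % 64

assert snap_to_multiple_of_64(1023, 768) == (960, 768)
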
@@ -143,6 +147,8 @@ def run_txt2img(
if st.button("Sample"):
st.write(f"**Model I:** {version}")
outputs = st.empty()
st.text("Sampling")
out = do_sample(
state["model"],
sampler,
@@ -156,6 +162,9 @@ def run_txt2img(
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)
return out
@@ -184,9 +193,7 @@ def run_img2img(
prompt=prompt,
negative_prompt=negative_prompt,
)
strength = st.number_input(
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
)
strength = st.number_input("**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0)
sampler, num_rows, num_cols = init_sampling(
img2img_strength=strength,
stage2strength=stage2strength,
@@ -194,6 +201,8 @@ def run_img2img(
num_samples = num_rows * num_cols
if st.button("Sample"):
outputs = st.empty()
st.text("Sampling")
out = do_img2img(
repeat(img, "1 ... -> n ...", n=num_samples),
state["model"],
@@ -204,6 +213,7 @@ def run_img2img(
return_latents=return_latents,
filter=filter,
)
show_samples(out, outputs)
return out
@@ -342,6 +352,7 @@ if __name__ == "__main__":
samples_z = None
if add_pipeline and samples_z is not None:
outputs = st.empty()
st.write("**Running Refinement Stage**")
samples = apply_refiner(
samples_z,
@@ -353,6 +364,7 @@ if __name__ == "__main__":
filter=state.get("filter"),
finish_denoising=finish_denoising,
)
show_samples(samples, outputs)
if save_locally and samples is not None:
perform_save_locally(save_path, samples)
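
The last two hunks wire the refiner into the demo loop: the base pass keeps its latents, and apply_refiner denoises them before display and saving. do_sample and do_img2img return either samples or a (samples, latents) tuple depending on return_latents; the unpacking convention, runnable in isolation with a stand-in sampler:

import torch

def fake_sample(return_latents=False):  # stand-in for do_sample/do_img2img
    samples, latents = torch.rand(1, 3, 64, 64), torch.rand(1, 4, 8, 8)
    return (samples, latents) if return_latents else samples

out = fake_sample(return_latents=True)
samples, samples_z = out if isinstance(out, tuple) else (out, None)
assert samples_z is not None  # the refinement stage is only reachable with latents
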

View File: scripts/demo/streamlit_helpers.py

@@ -1,18 +1,13 @@
import math
import os
from typing import List, Union
import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig, OmegaConf
from omegaconf import OmegaConf
from PIL import Image
from safetensors.torch import load_file as load_safetensors
from torch import autocast
from torchvision import transforms
from torchvision.utils import make_grid
from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.modules.diffusionmodules.sampling import (
@@ -23,52 +18,12 @@ from sgm.modules.diffusionmodules.sampling import (
HeunEDMSampler,
LinearMultistepSampler,
)
from sgm.util import append_dims, instantiate_from_config
class WatermarkEmbedder:
def __init__(self, watermark):
self.watermark = watermark
self.num_bits = len(WATERMARK_BITS)
self.encoder = WatermarkEncoder()
self.encoder.set_watermark("bits", self.watermark)
def __call__(self, image: torch.Tensor):
"""
Adds a predefined watermark to the input image
Args:
image: ([N,] B, C, H, W) in range [0, 1]
Returns:
same as input but watermarked
"""
# watermarking library expects input as cv2 BGR format
squeeze = len(image.shape) == 4
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
image = torch.from_numpy(
rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
).to(image.device)
image = torch.clamp(image / 255, min=0.0, max=1.0)
if squeeze:
image = image[0]
return image
# A fixed 48-bit message that was chosen at random
# WATERMARK_MESSAGE = 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] gives bits of x as str, use int to convert them to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
embed_watemark = WatermarkEmbedder(WATERMARK_BITS)
from sgm.inference.helpers import (
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
embed_watermark,
)
from sgm.util import load_model_from_config
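
The watermark machinery deleted above now lives in sgm.inference.helpers and is imported as embed_watermark (fixing the old embed_watemark typo) rather than redefined. The 48-bit message-to-bits conversion it relies on is easy to sanity-check in isolation:

WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# bin(x)[2:] yields the bits of x as a string; int() maps each character to 0/1
WATERMARK_BITS = [int(bit) for bit in bin(WATERMARK_MESSAGE)[2:]]
assert len(WATERMARK_BITS) == 48
assert hex(WATERMARK_MESSAGE) == "0xb3ec907bb19e"  # the value noted in the old comment
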
@st.cache_resource()
@@ -79,9 +34,8 @@ def init_st(version_dict, load_ckpt=True, load_filter=True):
ckpt = version_dict["ckpt"]
config = OmegaConf.load(config)
model, msg = load_model_from_config(config, ckpt if load_ckpt else None)
model = load_model_from_config(config, ckpt if load_ckpt else None, freeze=False)
state["msg"] = msg
state["model"] = model
state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config
@@ -90,10 +44,6 @@ def init_st(version_dict, load_ckpt=True, load_filter=True):
return state
def load_model(model):
model.cuda()
lowvram_mode = False
@@ -111,48 +61,6 @@ def initial_model_load(model):
return model
def unload_model(model):
global lowvram_mode
if lowvram_mode:
model.cpu()
torch.cuda.empty_cache()
def load_model_from_config(config, ckpt=None, verbose=True):
model = instantiate_from_config(config.model)
if ckpt is not None:
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
global_step = pl_sd["global_step"]
st.info(f"loaded ckpt from global step {global_step}")
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
else:
raise NotImplementedError
msg = None
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
else:
msg = None
model = initial_model_load(model)
model.eval()
return model, msg
def get_unique_embedder_keys_from_conditioner(conditioner):
return list(set([x.input_key for x in conditioner.embedders]))
@@ -209,7 +117,7 @@ def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
def perform_save_locally(save_path, samples):
os.makedirs(os.path.join(save_path), exist_ok=True)
base_count = len(os.listdir(os.path.join(save_path)))
samples = embed_watemark(samples)
samples = embed_watermark(samples)
for sample in samples:
sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(sample.astype(np.uint8)).save(
@@ -228,58 +136,12 @@ def init_save_locally(_dir, init_value: bool = False):
return save_locally, save_path
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 1.0):
self.discretization = discretization
self.strength = strength
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
sigmas = sigmas[: max(int(self.strength * len(sigmas)), 1)]
print("prune index:", max(int(self.strength * len(sigmas)), 1))
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
def show_samples(samples, outputs):
if isinstance(samples, tuple):
samples, _ = samples
grid = embed_watermark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
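
show_samples is the new shared display path: both demo runners hand their output and an st.empty() placeholder here instead of inlining the watermark-and-grid code. The rearrange lays the batch out horizontally in a single image; a shape-level check:

import torch
from einops import rearrange

samples = torch.rand(2, 3, 64, 64)  # (b, c, h, w) in [0, 1]
grid = rearrange(torch.stack([samples]), "n b c h w -> (n h) (b w) c")
assert grid.shape == (64, 128, 3)  # two samples side by side
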
def get_guider(key):
@@ -292,13 +154,9 @@ def get_guider(key):
)
if guider == "IdentityGuider":
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
elif guider == "VanillaCFG":
scale = st.number_input(
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
)
scale = st.number_input(f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0)
thresholder = st.sidebar.selectbox(
f"Thresholder #{key}",
@@ -331,13 +189,9 @@ def init_sampling(
):
num_rows, num_cols = 1, 1
if specify_num_samples:
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
)
num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
steps = st.sidebar.number_input(
f"steps #{key}", value=40, min_value=1, max_value=1000
)
steps = st.sidebar.number_input(f"steps #{key}", value=40, min_value=1, max_value=1000)
sampler = st.sidebar.selectbox(
f"Sampler #{key}",
[
@@ -364,9 +218,7 @@ def init_sampling(
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
if img2img_strength < 1.0:
st.warning(
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
)
st.warning(f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper")
sampler.discretization = Img2ImgDiscretizationWrapper(
sampler.discretization, strength=img2img_strength
)
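
The wrapper's strength semantics deserve a worked case: sigmas are flipped so the smallest come first, the first strength-fraction is kept (at least one entry), and the order is restored. For a 40-entry schedule and strength 0.75, the 30 smallest sigmas survive, so sampling starts part-way down the noise ladder:

import torch

sigmas = torch.linspace(14.6, 0.03, 40)  # illustrative schedule, largest first
strength = 0.75
kept = torch.flip(torch.flip(sigmas, (0,))[: max(int(strength * len(sigmas)), 1)], (0,))
assert len(kept) == 30 and kept[0] < sigmas[0]
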
@@ -427,10 +279,7 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
s_noise=s_noise,
verbose=True,
)
elif (
sampler_name == "EulerAncestralSampler"
or sampler_name == "DPMPP2SAncestralSampler"
):
elif sampler_name == "EulerAncestralSampler" or sampler_name == "DPMPP2SAncestralSampler":
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
@@ -507,238 +356,3 @@ def get_init_img(batch_size=1, key=None):
init_image = load_img(key=key).cuda()
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
return init_image
def do_sample(
model,
sampler,
value_dict,
num_samples,
H,
W,
C,
F,
force_uc_zero_embeddings: List = None,
batch2model_input: List = None,
return_latents=False,
filter=None,
):
if force_uc_zero_embeddings is None:
force_uc_zero_embeddings = []
if batch2model_input is None:
batch2model_input = []
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
num_samples = [num_samples]
load_model(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
num_samples,
)
for key in batch:
if isinstance(batch[key], torch.Tensor):
print(key, batch[key].shape)
elif isinstance(batch[key], list):
print(key, [len(l) for l in batch[key]])
else:
print(key, batch[key])
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
unload_model(model.conditioner)
for k in c:
if not k == "crossattn":
c[k], uc[k] = map(
lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
)
additional_model_inputs = {}
for k in batch2model_input:
additional_model_inputs[k] = batch[k]
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
load_model(model.denoiser)
load_model(model.model)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
unload_model(model.model)
unload_model(model.denoiser)
load_model(model.first_stage_model)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
unload_model(model.first_stage_model)
if filter is not None:
samples = filter(samples)
grid = torch.stack([samples])
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
# Hardcoded demo setups; might undergo some changes in the future
batch = {}
batch_uc = {}
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
torch.tensor([value_dict["orig_height"], value_dict["orig_width"]])
.to(device)
.repeat(*N, 1)
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
)
elif key == "aesthetic_score":
batch["aesthetic_score"] = (
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
)
elif key == "target_size_as_tuple":
batch["target_size_as_tuple"] = (
torch.tensor([value_dict["target_height"], value_dict["target_width"]])
.to(device)
.repeat(*N, 1)
)
else:
batch[key] = value_dict[key]
for key in batch.keys():
if key not in batch_uc and isinstance(batch[key], torch.Tensor):
batch_uc[key] = torch.clone(batch[key])
return batch, batch_uc
@torch.no_grad()
def do_img2img(
img,
model,
sampler,
value_dict,
num_samples,
force_uc_zero_embeddings=[],
additional_kwargs={},
offset_noise_level: int = 0.0,
return_latents=False,
skip_encode=False,
filter=None,
add_noise=True,
):
st.text("Sampling")
outputs = st.empty()
precision_scope = autocast
with torch.no_grad():
with precision_scope("cuda"):
with model.ema_scope():
load_model(model.conditioner)
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
[num_samples],
)
c, uc = model.conditioner.get_unconditional_conditioning(
batch,
batch_uc=batch_uc,
force_uc_zero_embeddings=force_uc_zero_embeddings,
)
unload_model(model.conditioner)
for k in c:
c[k], uc[k] = map(lambda y: y[k][:num_samples].to("cuda"), (c, uc))
for k in additional_kwargs:
c[k] = uc[k] = additional_kwargs[k]
if skip_encode:
z = img
else:
load_model(model.first_stage_model)
z = model.encode_first_stage(img)
unload_model(model.first_stage_model)
noise = torch.randn_like(z)
sigmas = sampler.discretization(sampler.num_steps).cuda()
sigma = sigmas[0]
st.info(f"all sigmas: {sigmas}")
st.info(f"noising sigma: {sigma}")
if offset_noise_level > 0.0:
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
if add_noise:
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
else:
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
load_model(model.denoiser)
load_model(model.model)
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
unload_model(model.model)
unload_model(model.denoiser)
load_model(model.first_stage_model)
samples_x = model.decode_first_stage(samples_z)
unload_model(model.first_stage_model)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
grid = embed_watemark(torch.stack([samples]))
grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
outputs.image(grid.cpu().numpy())
if return_latents:
return samples, samples_z
return samples
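
Everything from do_sample down disappears from this file in favor of the sgm.inference.helpers versions, with the Streamlit-specific st.text/st.empty/grid code moved out to the callers shown earlier. One detail worth isolating is the noising step in do_img2img: the latent is pushed to the schedule's first sigma and rescaled by 1/sqrt(1 + sigma^2), which the inline comment concedes is hardcoded DDPM-like scaling. A standalone sketch with a scalar sigma (illustrative function name):

import torch

def noise_to_sigma(z, sigma, add_noise=True):
    # z + sigma * eps has std sqrt(1 + sigma^2) for unit-variance z; rescale it back
    if add_noise:
        z = z + sigma * torch.randn_like(z)
    return z / (1.0 + sigma**2) ** 0.5

z = torch.randn(4, 4, 64, 64)
assert abs(noise_to_sigma(z, 14.6).std().item() - 1.0) < 0.1
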

View File: sgm/inference/api.py

@@ -7,6 +7,7 @@ from sgm.inference.helpers import (
do_sample,
do_img2img,
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
)
from sgm.modules.diffusionmodules.sampling import (
EulerEDMSampler,
@@ -180,30 +181,20 @@ class SamplingPipeline:
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
if not os.path.exists(model_path):
# This supports development installs where checkpoints is root level of the repo
model_path = (
pathlib.Path(__file__).parent.parent.parent.resolve()
/ "checkpoints"
)
model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
if config_path is None:
config_path = (
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
)
config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
if not os.path.exists(config_path):
# This supports development installs where configs is root level of the repo
config_path = (
pathlib.Path(__file__).parent.parent.parent.resolve()
/ "configs/inference"
pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
)
self.config = str(config_path / self.specs.config)
self.ckpt = str(model_path / self.specs.ckpt)
if not os.path.exists(self.config):
raise ValueError(
f"Config {self.config} not found, check model spec or config_path"
)
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
if not os.path.exists(self.ckpt):
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
self.device = device
self.model = self._load_model(device=device, use_fp16=use_fp16)
@@ -225,8 +216,13 @@ class SamplingPipeline:
negative_prompt: str = "",
samples: int = 1,
return_latents: bool = False,
stage2strength=None,
):
sampler = get_sampler_config(params)
if stage2strength is not None:
sampler.discretization = Txt2NoisyDiscretizationWrapper(
sampler.discretization, strength=stage2strength, original_steps=params.steps
)
value_dict = asdict(params)
value_dict["prompt"] = prompt
value_dict["negative_prompt"] = negative_prompt
@@ -279,10 +275,7 @@ class SamplingPipeline:
)
def wrap_discretization(self, discretization, strength=1.0):
if (
not isinstance(discretization, Img2ImgDiscretizationWrapper)
and strength < 1.0
):
if not isinstance(discretization, Img2ImgDiscretizationWrapper) and strength < 1.0:
return Img2ImgDiscretizationWrapper(discretization, strength=strength)
return discretization
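
The collapsed conditional also makes the guard easier to read: wrap_discretization wraps only when strength < 1.0 and the discretization is not already an Img2ImgDiscretizationWrapper, so repeated calls cannot stack wrappers. The property in miniature, reusing the pipeline from the sketch above and a stand-in discretization:

import torch

disc = lambda steps: torch.linspace(14.6, 0.03, steps)  # stand-in discretization
wrapped = pipeline.wrap_discretization(disc, strength=0.5)
assert isinstance(wrapped, Img2ImgDiscretizationWrapper)
assert pipeline.wrap_discretization(wrapped, strength=0.5) is wrapped
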
@@ -329,9 +322,7 @@ class SamplingPipeline:
def get_guider_config(params: SamplingParams):
if params.guider == Guider.IDENTITY:
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
elif params.guider == Guider.VANILLA:
scale = params.scale

View File: sgm/inference/helpers.py

@@ -35,9 +35,9 @@ class WatermarkEmbedder:
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
:, :, :, ::-1
]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
@@ -98,6 +98,36 @@ class Img2ImgDiscretizationWrapper:
return sigmas
class Txt2NoisyDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas
params:
strength: float between 0.0 and 1.0. 0.0 means full sampling (all sigmas are returned)
"""
def __init__(self, discretization, strength: float = 0.0, original_steps=None):
self.discretization = discretization
self.strength = strength
self.original_steps = original_steps
assert 0.0 <= self.strength <= 1.0
def __call__(self, *args, **kwargs):
# sigmas start large first, and decrease then
sigmas = self.discretization(*args, **kwargs)
print(f"sigmas after discretization, before pruning img2img: ", sigmas)
sigmas = torch.flip(sigmas, (0,))
if self.original_steps is None:
steps = len(sigmas)
else:
steps = self.original_steps + 1
prune_index = max(min(int(self.strength * steps) - 1, steps - 1), 0)
sigmas = sigmas[prune_index:]
print("prune index:", prune_index)
sigmas = torch.flip(sigmas, (0,))
print(f"sigmas after pruning: ", sigmas)
return sigmas
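
This class is Txt2NoisyDiscretizationWrapper moved verbatim from streamlit_helpers. Its index arithmetic rewards a worked case: with original_steps=40 and strength=0.15, steps is 41 and the five lowest-noise sigmas are dropped, leaving that residual noise for the refiner:

strength, original_steps = 0.15, 40
steps = original_steps + 1  # mirrors the wrapper's __call__
prune_index = max(min(int(strength * steps) - 1, steps - 1), 0)
assert prune_index == 5
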
def do_sample(
model,
sampler,
@@ -154,13 +184,15 @@ def do_sample(
randn = torch.randn(shape).to(device)
def denoiser(input, sigma, c):
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
with ModelOnDevice(model.denoiser, device):
with ModelOnDevice(model.model, device):
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
with ModelOnDevice(model.first_stage_model, device):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
@@ -179,14 +211,10 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
@@ -196,9 +224,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
.to(device)
.repeat(*N, 1)
)
@@ -207,9 +233,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
)
elif key == "target_size_as_tuple":
@@ -230,9 +254,7 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
def get_input_image_tensor(image: Image.Image, device="cuda"):
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((width, height))
image_array = np.array(image.convert("RGB"))
image_array = image_array[None].transpose(0, 3, 1, 2)
@@ -252,10 +274,11 @@ def do_img2img(
return_latents=False,
skip_encode=False,
filter=None,
add_noise=True,
device="cuda",
):
with torch.no_grad():
with autocast(device) as precision_scope:
with autocast(device):
with model.ema_scope():
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
@@ -285,17 +308,24 @@ def do_img2img(
noise = noise + offset_noise_level * append_dims(
torch.randn(z.shape[0], device=z.device), z.ndim
)
noised_z = z + noise * append_dims(sigma, z.ndim)
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
if add_noise:
noised_z = z + noise * append_dims(sigma, z.ndim).cuda()
noised_z = noised_z / torch.sqrt(
1.0 + sigmas[0] ** 2.0
) # Note: hardcoded to DDPM-like scaling. need to generalize later.
else:
noised_z = z / torch.sqrt(1.0 + sigmas[0] ** 2.0)
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
with ModelOnDevice(model.denoiser, device):
with ModelOnDevice(model.model, device):
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
with ModelOnDevice(model.first_stage_model, device):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
if filter is not None:
samples = filter(samples)
@@ -303,3 +333,28 @@ def do_img2img(
if return_latents:
return samples, samples_z
return samples
class ModelOnDevice(object):
def __init__(self, model, device):
self.model = model
self.device = device
self.original_device = model.device
def __enter__(self):
if self.device != self.original_device:
self.model.to(self.device)
def __exit__(self, *args):
if self.device != self.original_device:
self.model.to(self.original_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()
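
ModelOnDevice is the context-manager replacement for the demo's paired load_model/unload_model calls: on entry the model is moved to the target device if it differs from its current one, and on exit the original placement is restored and the CUDA cache emptied. A usage sketch with a toy module; note that ModelOnDevice reads model.device, which LightningModule-style wrappers provide and plain nn.Modules need added, as below:

import torch
import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, x):
        return self.linear(x)

model = TinyModel()  # starts on CPU
with ModelOnDevice(model, "cuda"):  # assumes a CUDA device is available
    model(torch.randn(1, 4, device="cuda"))
assert model.device.type == "cpu"  # placement restored on exit
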
def load_model(model, device):
if model.device != device:
old_device = model.device
model.to(device)
return old_device
return False
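
The trailing load_model variant covers callers that cannot use a with-block: it returns the model's previous device when a move happened and False otherwise, so placement can be restored by hand. A sketch of the intended pattern, reusing the TinyModel above:

old_device = load_model(model, "cuda")  # previous device, or False if no move
try:
    model(torch.randn(1, 4, device="cuda"))
finally:
    if old_device:
        model.to(old_device)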