formatting, remove reference

This commit is contained in:
Stephan Auerhahn
2023-08-06 11:30:40 +00:00
parent b216934b7e
commit f06c67c206
4 changed files with 74 additions and 29 deletions

View File

@@ -105,7 +105,9 @@ def load_img(display=True, key=None, device="cuda"):
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
@@ -193,7 +195,9 @@ def run_img2img(
prompt=prompt,
negative_prompt=negative_prompt,
)
strength = st.number_input("**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0)
strength = st.number_input(
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
)
sampler, num_rows, num_cols = init_sampling(
img2img_strength=strength,
stage2strength=stage2strength,
@@ -280,8 +284,6 @@ if __name__ == "__main__":
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
state = init_st(version_dict, load_filter=True)
if state["msg"]:
st.info(state["msg"])
model = state["model"]
is_legacy = version_dict["is_legacy"]
@@ -308,7 +310,6 @@ if __name__ == "__main__":
version_dict2 = VERSION2SPECS[version2]
state2 = init_st(version_dict2, load_filter=False)
st.info(state2["msg"])
stage2strength = st.number_input(
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0

View File

@@ -34,7 +34,9 @@ def init_st(version_dict, load_ckpt=True, load_filter=True):
ckpt = version_dict["ckpt"]
config = OmegaConf.load(config)
model = load_model_from_config(config, ckpt if load_ckpt else None, freeze=False)
model = load_model_from_config(
config, ckpt if load_ckpt else None, freeze=False
)
state["model"] = model
state["ckpt"] = ckpt if load_ckpt else None
@@ -154,9 +156,13 @@ def get_guider(key):
)
if guider == "IdentityGuider":
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif guider == "VanillaCFG":
scale = st.number_input(f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0)
scale = st.number_input(
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
)
thresholder = st.sidebar.selectbox(
f"Thresholder #{key}",
@@ -189,9 +195,13 @@ def init_sampling(
):
num_rows, num_cols = 1, 1
if specify_num_samples:
num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
)
steps = st.sidebar.number_input(f"steps #{key}", value=40, min_value=1, max_value=1000)
steps = st.sidebar.number_input(
f"steps #{key}", value=40, min_value=1, max_value=1000
)
sampler = st.sidebar.selectbox(
f"Sampler #{key}",
[
@@ -218,7 +228,9 @@ def init_sampling(
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
if img2img_strength < 1.0:
st.warning(f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper")
st.warning(
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
)
sampler.discretization = Img2ImgDiscretizationWrapper(
sampler.discretization, strength=img2img_strength
)
@@ -279,7 +291,10 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
s_noise=s_noise,
verbose=True,
)
elif sampler_name == "EulerAncestralSampler" or sampler_name == "DPMPP2SAncestralSampler":
elif (
sampler_name == "EulerAncestralSampler"
or sampler_name == "DPMPP2SAncestralSampler"
):
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)

View File

@@ -181,20 +181,30 @@ class SamplingPipeline:
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
if not os.path.exists(model_path):
# This supports development installs where checkpoints is root level of the repo
model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
model_path = (
pathlib.Path(__file__).parent.parent.parent.resolve()
/ "checkpoints"
)
if config_path is None:
config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
config_path = (
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
)
if not os.path.exists(config_path):
# This supports development installs where configs is root level of the repo
config_path = (
pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
pathlib.Path(__file__).parent.parent.parent.resolve()
/ "configs/inference"
)
self.config = str(config_path / self.specs.config)
self.ckpt = str(model_path / self.specs.ckpt)
if not os.path.exists(self.config):
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
raise ValueError(
f"Config {self.config} not found, check model spec or config_path"
)
if not os.path.exists(self.ckpt):
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
self.device = device
self.model = self._load_model(device=device, use_fp16=use_fp16)
@@ -221,7 +231,9 @@ class SamplingPipeline:
sampler = get_sampler_config(params)
if stage2strength is not None:
sampler.discretization = Txt2NoisyDiscretizationWrapper(
sampler.discretization, strength=stage2strength, original_steps=params.steps
sampler.discretization,
strength=stage2strength,
original_steps=params.steps,
)
value_dict = asdict(params)
value_dict["prompt"] = prompt
@@ -275,7 +287,10 @@ class SamplingPipeline:
)
def wrap_discretization(self, discretization, strength=1.0):
if not isinstance(discretization, Img2ImgDiscretizationWrapper) and strength < 1.0:
if (
not isinstance(discretization, Img2ImgDiscretizationWrapper)
and strength < 1.0
):
return Img2ImgDiscretizationWrapper(discretization, strength=strength)
return discretization
@@ -322,7 +337,9 @@ class SamplingPipeline:
def get_guider_config(params: SamplingParams):
if params.guider == Guider.IDENTITY:
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
guider_config = {
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
}
elif params.guider == Guider.VANILLA:
scale = params.scale

View File

@@ -35,9 +35,9 @@ class WatermarkEmbedder:
if squeeze:
image = image[None, ...]
n = image.shape[0]
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
:, :, :, ::-1
]
image_np = rearrange(
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
).numpy()[:, :, :, ::-1]
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
for k in range(image_np.shape[0]):
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
@@ -184,7 +184,9 @@ def do_sample(
randn = torch.randn(shape).to(device)
def denoiser(input, sigma, c):
return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
return model.denoiser(
model.model, input, sigma, c, **additional_model_inputs
)
with ModelOnDevice(model.denoiser, device):
with ModelOnDevice(model.model, device):
@@ -211,10 +213,14 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
for key in keys:
if key == "txt":
batch["txt"] = (
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
batch_uc["txt"] = (
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
.reshape(N)
.tolist()
)
elif key == "original_size_as_tuple":
batch["original_size_as_tuple"] = (
@@ -224,7 +230,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
)
elif key == "crop_coords_top_left":
batch["crop_coords_top_left"] = (
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
torch.tensor(
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
)
.to(device)
.repeat(*N, 1)
)
@@ -233,7 +241,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
)
batch_uc["aesthetic_score"] = (
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
torch.tensor([value_dict["negative_aesthetic_score"]])
.to(device)
.repeat(*N, 1)
)
elif key == "target_size_as_tuple":
@@ -254,7 +264,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
def get_input_image_tensor(image: Image.Image, device="cuda"):
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = image.resize((width, height))
image_array = np.array(image.convert("RGB"))
image_array = image_array[None].transpose(0, 3, 1, 2)