Files
generative-models/sgm/inference/helpers.py
Stephan Auerhahn 931d7a389a Add inference helpers & tests (#57)
* Add inference helpers & tests

* Support testing with hatch

* fixes to hatch script

* add inference test action

* change workflow trigger

* widen trigger to test

* revert changes to workflow triggers

* Install local python in action

* Trigger on push again

* fix python version

* add CODEOWNERS and change triggers

* Report tests results

* update action versions

* format

* Fix typo and add refiner helper

* use a shared path loaded from a secret for checkpoints source

* typo fix

* Use device from input and remove duplicated code

* PR feedback

* fix call to load_model_from_config

* Move model to gpu

* Refactor helpers

* cleanup

* test refiner, prep for 1.0, align with metadata

* fix paths on second load

* deduplicate streamlit code

* filenames

* fixes

* add pydantic to requirements

* fix usage of `msg` in demo script

* remove double text

* run black

* fix streamlit sampling when returning latents

* extract function for streamlit output

* another fix for streamlit outputs

* fix img2img in streamlit

* Make fp16 optional and fix device param

* PR feedback

* fix dict cast for dataclass

* run black, update ci script

* cache pip dependencies on hosted runners, remove extra runs

* install package in ci env

* fix cache path

* PR cleanup

* one more cleanup

* don't cache, it filled up
2023-07-26 04:37:24 -07:00

306 lines
10 KiB
Python

import os
from typing import Union, List, Optional
import math
import numpy as np
import torch
from PIL import Image
from einops import rearrange
from imwatermark import WatermarkEncoder
from omegaconf import ListConfig
from torch import autocast
from sgm.util import append_dims
class WatermarkEmbedder:
    """Callable that embeds a fixed bit-pattern watermark into image tensors.

    The watermark bits are registered once with a ``WatermarkEncoder`` at
    construction time; calling the instance encodes them into every image of
    the given batch.
    """

    def __init__(self, watermark):
        """
        Args:
            watermark: sequence of bits (0/1) to embed into every image
        """
        self.watermark = watermark
        # Count the bits of the watermark that was actually passed in, rather
        # than relying on a module-level constant that may differ from it.
        self.num_bits = len(watermark)
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor):
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, C, H, W) in range [0, 1]

        Returns:
            same as input but watermarked
        """
        # A 4-d input has no leading N axis: add one here and drop it again
        # before returning so the output shape matches the input shape.
        squeeze = len(image.shape) == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]
        # torch (n, b, c, h, w) in [0, 1] -> numpy (n * b, h, w, c) in [0, 255]
        # The watermarking library expects input in cv2 BGR format, hence the
        # reversal of the channel axis.
        image_np = rearrange(
            (255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
        ).numpy()[:, :, :, ::-1]
        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
        # Reverse the channel axis back and restore the (n, b, c, h, w) layout
        # on the device of the input tensor.
        image = torch.from_numpy(
            rearrange(image_np[:, :, :, ::-1], "(n b) h w c -> n b c h w", n=n)
        ).to(image.device)
        image = torch.clamp(image / 255, min=0.0, max=1.0)
        if squeeze:
            image = image[0]
        return image
# Fixed 48-bit watermark payload; the value itself was picked at random.
# The same number written in hexadecimal: 0xB3EC907BB19E
WATERMARK_MESSAGE = 0b101100111110110010010000011110111011000110011110
# format(x, "b") renders x in base 2 without the "0b" prefix; every character
# of that string is then converted to an integer 0/1.
WATERMARK_BITS = list(map(int, format(WATERMARK_MESSAGE, "b")))
# Module-wide embedder instance used when saving samples.
embed_watermark = WatermarkEmbedder(WATERMARK_BITS)
def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    Duplicates are dropped while the order of first appearance is kept, so the
    result is the same on every run (de-duplicating through a ``set`` yields
    an order that depends on string hashing and can change between runs).

    Args:
        conditioner: object exposing an iterable ``embedders`` attribute whose
            items each have an ``input_key`` attribute

    Returns:
        list of unique input keys in first-seen order
    """
    # dict keys are unique and remember insertion order.
    return list(dict.fromkeys(x.input_key for x in conditioner.embedders))
def perform_save_locally(save_path, samples):
    """Watermark ``samples`` and write each one as a numbered PNG in ``save_path``.

    The directory is created when missing. File names are zero-padded,
    nine-digit counters that start at the number of entries already present
    in the directory.

    Args:
        save_path: target directory
        samples: batch of images; each item is rearranged from (c, h, w) to
            (h, w, c) and scaled by 255 before saving (presumably values are
            in [0, 1] -- confirm against callers)
    """
    os.makedirs(save_path, exist_ok=True)
    first_index = len(os.listdir(save_path))
    watermarked = embed_watermark(samples)
    for base_count, sample in enumerate(watermarked, start=first_index):
        # channels-first tensor -> channels-last array in the 0..255 range
        pixels = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        target = os.path.join(save_path, f"{base_count:09}.png")
        Image.fromarray(pixels.astype(np.uint8)).save(target)
class Img2ImgDiscretizationWrapper:
    """
    wraps a discretizer, and prunes the sigmas
    params:
        strength: float between 0.0 and 1.0. 1.0 means full sampling (all sigmas are returned)
    """

    def __init__(self, discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        """Run the wrapped discretizer and keep only the trailing part of its output.

        All arguments are forwarded unchanged to the wrapped discretization.

        Returns:
            the last ``max(int(strength * len(sigmas)), 1)`` sigmas, in their
            original order
        """
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)
        # Compute the prune index once, from the full schedule, so that the
        # value reported below is the one actually used for slicing. At least
        # one sigma is always kept.
        prune_index = max(int(self.strength * len(sigmas)), 1)
        # Flip, cut from the front, flip back: keeps the trailing
        # `prune_index` entries of the schedule.
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas
def do_sample(
    model,
    sampler,
    value_dict,
    num_samples,
    H,
    W,
    C,
    F,
    force_uc_zero_embeddings: Optional[List] = None,
    batch2model_input: Optional[List] = None,
    return_latents=False,
    filter=None,
    device="cuda",
):
    """Draw samples from ``model`` starting from random noise.

    Runs under ``torch.no_grad()``, ``autocast(device)`` and
    ``model.ema_scope()``.

    Args:
        model: object providing ``conditioner``, ``denoiser``, ``model``,
            ``ema_scope()`` and ``decode_first_stage()``
        sampler: callable invoked as ``sampler(denoiser, noise, cond=c, uc=uc)``
        value_dict: conditioning values consumed by ``get_batch``
        num_samples: number of samples to generate
        H, W, C, F: the initial noise has shape
            ``(num_samples, C, H // F, W // F)``
        force_uc_zero_embeddings: forwarded to the conditioner; defaults to
            an empty list
        batch2model_input: batch keys whose values are passed to the denoiser
            as extra keyword arguments; defaults to an empty list
        return_latents: when true, also return the sampler output before
            decoding
        filter: optional callable applied to the decoded samples
        device: device for the conditioning tensors and the initial noise

    Returns:
        decoded samples clamped to [0, 1], or ``(samples, latents)`` when
        ``return_latents`` is true
    """
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if batch2model_input is None:
        batch2model_input = []

    with torch.no_grad():
        with autocast(device):
            with model.ema_scope():
                num_samples = [num_samples]
                batch_size = math.prod(num_samples)
                # Forward `device` so the conditioning tensors are created on
                # the requested device instead of get_batch's "cuda" default.
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    num_samples,
                    device=device,
                )
                # Log what is being fed to the conditioner.
                for key in batch:
                    if isinstance(batch[key], torch.Tensor):
                        print(key, batch[key].shape)
                    elif isinstance(batch[key], list):
                        print(key, [len(item) for item in batch[key]])
                    else:
                        print(key, batch[key])
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )

                # Every entry except "crossattn" is cut to the batch size and
                # moved to the target device.
                for k in c:
                    if not k == "crossattn":
                        c[k], uc[k] = (
                            cond[k][:batch_size].to(device) for cond in (c, uc)
                        )

                additional_model_inputs = {k: batch[k] for k in batch2model_input}

                shape = (batch_size, C, H // F, W // F)
                # Noise is drawn first and moved afterwards; creating it
                # directly on the device would use a different random stream.
                randn = torch.randn(shape).to(device)

                def denoiser(input, sigma, c):
                    return model.denoiser(
                        model.model, input, sigma, c, **additional_model_inputs
                    )

                samples_z = sampler(denoiser, randn, cond=c, uc=uc)
                samples_x = model.decode_first_stage(samples_z)
                # Map the decoder output from [-1, 1] to [0, 1].
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

                if return_latents:
                    return samples, samples_z
                return samples
def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
    """Assemble the conditional and unconditional conditioning batches.

    Args:
        keys: conditioner input keys to fill in
        value_dict: source values; which entries are read depends on ``keys``
        N: batch dimensions; every value is repeated to this shape
        device: device on which the tensors are placed

    Returns:
        ``(batch, batch_uc)``. ``batch_uc`` holds the negative prompt and the
        negative aesthetic score where those keys are requested, plus a clone
        of every other tensor entry of ``batch``.
    """
    # Hardcoded demo setups; might undergo some changes in the future

    def tile_text(text):
        # nested list of shape N with `text` at every position
        return np.repeat([text], repeats=math.prod(N)).reshape(N).tolist()

    def tile_values(values):
        # tensor of shape (*N, len(values)) on the requested device
        return torch.tensor(values).to(device).repeat(*N, 1)

    # conditioner key -> the two value_dict entries it is built from
    pair_sources = {
        "original_size_as_tuple": ("orig_height", "orig_width"),
        "crop_coords_top_left": ("crop_coords_top", "crop_coords_left"),
        "target_size_as_tuple": ("target_height", "target_width"),
    }

    batch, batch_uc = {}, {}
    for key in keys:
        if key == "txt":
            batch["txt"] = tile_text(value_dict["prompt"])
            batch_uc["txt"] = tile_text(value_dict["negative_prompt"])
        elif key == "aesthetic_score":
            batch["aesthetic_score"] = tile_values([value_dict["aesthetic_score"]])
            batch_uc["aesthetic_score"] = tile_values(
                [value_dict["negative_aesthetic_score"]]
            )
        elif key in pair_sources:
            first, second = pair_sources[key]
            batch[key] = tile_values([value_dict[first], value_dict[second]])
        else:
            # anything else is passed through untouched
            batch[key] = value_dict[key]

    # Tensor entries without a dedicated unconditional value are shared as clones.
    for key, value in batch.items():
        if isinstance(value, torch.Tensor) and key not in batch_uc:
            batch_uc[key] = value.clone()
    return batch, batch_uc
def get_input_image_tensor(image: Image.Image, device="cuda"):
    """Convert a PIL image into a (1, 3, H, W) float tensor in [-1, 1].

    The image is resized so that both sides are integer multiples of 64
    (rounded down). Sides shorter than 64 pixels are brought up to 64, since
    rounding them down would give a zero-sized, invalid resize target.

    Args:
        image: input image; converted to RGB
        device: device the resulting tensor is moved to

    Returns:
        float32 tensor of shape (1, 3, height, width) with values in [-1, 1]
    """
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")
    # resize to integer multiple of 64, never below 64
    width, height = (max(64, side - side % 64) for side in (w, h))
    image = image.resize((width, height))
    image_array = np.array(image.convert("RGB"))
    # (h, w, c) -> (1, c, h, w)
    image_array = image_array[None].transpose(0, 3, 1, 2)
    # uint8 [0, 255] -> float32 [-1, 1]
    image_tensor = torch.from_numpy(image_array).to(dtype=torch.float32) / 127.5 - 1.0
    return image_tensor.to(device)
def do_img2img(
    img,
    model,
    sampler,
    value_dict,
    num_samples,
    force_uc_zero_embeddings: Optional[List] = None,
    additional_kwargs: Optional[dict] = None,
    offset_noise_level: float = 0.0,
    return_latents=False,
    skip_encode=False,
    filter=None,
    device="cuda",
):
    """Sample from ``model`` starting from a noised version of ``img``.

    Runs under ``torch.no_grad()``, ``autocast(device)`` and
    ``model.ema_scope()``.

    Args:
        img: input image tensor, or an already encoded latent when
            ``skip_encode`` is true
        model: object providing ``conditioner``, ``denoiser``, ``model``,
            ``ema_scope()``, ``encode_first_stage()`` and
            ``decode_first_stage()``
        sampler: callable sampler exposing ``discretization`` and ``num_steps``
        value_dict: conditioning values consumed by ``get_batch``
        num_samples: number of samples to generate
        force_uc_zero_embeddings: forwarded to the conditioner; defaults to
            an empty list
        additional_kwargs: extra entries written into both the conditional
            and the unconditional conditioning; defaults to an empty dict
        offset_noise_level: when > 0, weight of an additional per-sample
            noise term added to the noise
        return_latents: when true, also return the sampler output before
            decoding
        skip_encode: use ``img`` directly as latent instead of encoding it
        filter: optional callable applied to the decoded samples
        device: device for the conditioning tensors

    Returns:
        decoded samples clamped to [0, 1], or ``(samples, latents)`` when
        ``return_latents`` is true
    """
    # None sentinels instead of mutable defaults, so no list/dict object is
    # shared between calls.
    if force_uc_zero_embeddings is None:
        force_uc_zero_embeddings = []
    if additional_kwargs is None:
        additional_kwargs = {}

    with torch.no_grad():
        with autocast(device):
            with model.ema_scope():
                # Forward `device` so the conditioning tensors are created on
                # the requested device instead of get_batch's "cuda" default.
                batch, batch_uc = get_batch(
                    get_unique_embedder_keys_from_conditioner(model.conditioner),
                    value_dict,
                    [num_samples],
                    device=device,
                )
                c, uc = model.conditioner.get_unconditional_conditioning(
                    batch,
                    batch_uc=batch_uc,
                    force_uc_zero_embeddings=force_uc_zero_embeddings,
                )

                # Cut every conditioning entry to the batch size and move it
                # to the target device.
                for k in c:
                    c[k], uc[k] = (
                        cond[k][:num_samples].to(device) for cond in (c, uc)
                    )

                for k in additional_kwargs:
                    c[k] = uc[k] = additional_kwargs[k]

                if skip_encode:
                    z = img
                else:
                    z = model.encode_first_stage(img)

                noise = torch.randn_like(z)
                sigmas = sampler.discretization(sampler.num_steps)
                # The first sigma of the schedule sets the noise level.
                sigma = sigmas[0].to(z.device)

                if offset_noise_level > 0.0:
                    # One extra random value per sample, broadcast over the
                    # remaining dimensions of z.
                    noise = noise + offset_noise_level * append_dims(
                        torch.randn(z.shape[0], device=z.device), z.ndim
                    )
                noised_z = z + noise * append_dims(sigma, z.ndim)
                # Note: hardcoded to DDPM-like scaling. need to generalize later.
                noised_z = noised_z / torch.sqrt(1.0 + sigma**2.0)

                def denoiser(x, sigma, c):
                    return model.denoiser(model.model, x, sigma, c)

                samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
                samples_x = model.decode_first_stage(samples_z)
                # Map the decoder output from [-1, 1] to [0, 1].
                samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)

                if filter is not None:
                    samples = filter(samples)

                if return_latents:
                    return samples, samples_z
                return samples