import os

import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms

from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.modules.diffusionmodules.sampling import (
    DPMPP2MSampler,
    DPMPP2SAncestralSampler,
    EulerAncestralSampler,
    EulerEDMSampler,
    HeunEDMSampler,
    LinearMultistepSampler,
)
from sgm.inference.helpers import (
    Img2ImgDiscretizationWrapper,
    Txt2NoisyDiscretizationWrapper,
    embed_watermark,
)
from sgm.util import load_model_from_config


@st.cache_resource()
def init_st(version_dict, load_ckpt=True, load_filter=True):
    """Load config, weights and (optionally) the safety filter once and cache them across Streamlit reruns."""
    state = dict()
    if "model" not in state:
        config = version_dict["config"]
        ckpt = version_dict["ckpt"]

        config = OmegaConf.load(config)
        model = load_model_from_config(
            config, ckpt if load_ckpt else None, freeze=False
        )

        state["model"] = model
        state["ckpt"] = ckpt if load_ckpt else None
        state["config"] = config
        if load_filter:
            state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state


lowvram_mode = False


def set_lowvram_mode(mode):
    global lowvram_mode
    lowvram_mode = mode


def initial_model_load(model):
    # In low-VRAM mode, keep the weights in half precision instead of moving
    # the full model to the GPU right away.
    global lowvram_mode
    if lowvram_mode:
        model.model.half()
    else:
        model.cuda()
    return model


def get_unique_embedder_keys_from_conditioner(conditioner):
    return list(set([x.input_key for x in conditioner.embedders]))


def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    # Hardcoded demo settings; might undergo some changes in the future
    value_dict = {}
    for key in keys:
        if key == "txt":
            if prompt is None:
                prompt = st.text_input(
                    "Prompt", "A professional photograph of an astronaut riding a pig"
                )
            if negative_prompt is None:
                negative_prompt = st.text_input("Negative prompt", "")

            value_dict["prompt"] = prompt
            value_dict["negative_prompt"] = negative_prompt

        if key == "original_size_as_tuple":
            orig_width = st.number_input(
                "orig_width",
                value=init_dict["orig_width"],
                min_value=16,
            )
            orig_height = st.number_input(
                "orig_height",
                value=init_dict["orig_height"],
                min_value=16,
            )

            value_dict["orig_width"] = orig_width
            value_dict["orig_height"] = orig_height

        if key == "crop_coords_top_left":
            crop_coord_top = st.number_input("crop_coords_top", value=0, min_value=0)
            crop_coord_left = st.number_input("crop_coords_left", value=0, min_value=0)

            value_dict["crop_coords_top"] = crop_coord_top
            value_dict["crop_coords_left"] = crop_coord_left

        if key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5

        if key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict


def perform_save_locally(save_path, samples):
    # Embed the invisible watermark and write samples as sequentially numbered PNGs.
    os.makedirs(os.path.join(save_path), exist_ok=True)
    base_count = len(os.listdir(os.path.join(save_path)))
    samples = embed_watermark(samples)
    for sample in samples:
        sample = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        Image.fromarray(sample.astype(np.uint8)).save(
            os.path.join(save_path, f"{base_count:09}.png")
        )
        base_count += 1


def init_save_locally(_dir, init_value: bool = False):
    save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
    if save_locally:
        save_path = st.text_input("Save path", value=os.path.join(_dir, "samples"))
    else:
        save_path = None

    return save_locally, save_path


def show_samples(samples, outputs):
    if isinstance(samples, tuple):
        samples, _ = samples
    grid = embed_watermark(torch.stack([samples]))
    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
    outputs.image(grid.cpu().numpy())


def get_guider(key):
    guider = st.sidebar.selectbox(
        f"Guider #{key}",
        [
            "VanillaCFG",
            "IdentityGuider",
        ],
    )

    if guider == "IdentityGuider":
        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
        }
    elif guider == "VanillaCFG":
        scale = st.number_input(
            f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
        )

        thresholder = st.sidebar.selectbox(
            f"Thresholder #{key}",
            [
                "None",
            ],
        )

        if thresholder == "None":
            dyn_thresh_config = {
                "target": "sgm.modules.diffusionmodules.sampling_utils.NoDynamicThresholding"
            }
        else:
            raise NotImplementedError

        guider_config = {
            "target": "sgm.modules.diffusionmodules.guiders.VanillaCFG",
            "params": {"scale": scale, "dyn_thresh_config": dyn_thresh_config},
        }
    else:
        raise NotImplementedError
    return guider_config


def init_sampling(
    key=1,
    img2img_strength=1.0,
    specify_num_samples=True,
    stage2strength=None,
):
    # Build a sampler from the sidebar widgets, optionally wrapping its
    # discretization for img2img strength or two-stage (refiner) noising.
    num_rows, num_cols = 1, 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
        )

    steps = st.sidebar.number_input(
        f"steps #{key}", value=40, min_value=1, max_value=1000
    )
    sampler = st.sidebar.selectbox(
        f"Sampler #{key}",
        [
            "EulerEDMSampler",
            "HeunEDMSampler",
            "EulerAncestralSampler",
            "DPMPP2SAncestralSampler",
            "DPMPP2MSampler",
            "LinearMultistepSampler",
        ],
        0,
    )
    discretization = st.sidebar.selectbox(
        f"Discretization #{key}",
        [
            "LegacyDDPMDiscretization",
            "EDMDiscretization",
        ],
    )

    discretization_config = get_discretization(discretization, key=key)

    guider_config = get_guider(key=key)

    sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
    if img2img_strength < 1.0:
        st.warning(
            f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
        )
        sampler.discretization = Img2ImgDiscretizationWrapper(
            sampler.discretization, strength=img2img_strength
        )
    if stage2strength is not None:
        sampler.discretization = Txt2NoisyDiscretizationWrapper(
            sampler.discretization, strength=stage2strength, original_steps=steps
        )
    return sampler, num_rows, num_cols


def get_discretization(discretization, key=1):
    if discretization == "LegacyDDPMDiscretization":
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
        }
    elif discretization == "EDMDiscretization":
        sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
        sigma_max = st.number_input(f"sigma_max #{key}", value=14.61)  # 14.6146
        rho = st.number_input(f"rho #{key}", value=3.0)
        discretization_config = {
            "target": "sgm.modules.diffusionmodules.discretizer.EDMDiscretization",
            "params": {
                "sigma_min": sigma_min,
                "sigma_max": sigma_max,
                "rho": rho,
            },
        }
    return discretization_config


def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1):
    if sampler_name == "EulerEDMSampler" or sampler_name == "HeunEDMSampler":
        s_churn = st.sidebar.number_input(f"s_churn #{key}", value=0.0, min_value=0.0)
        s_tmin = st.sidebar.number_input(f"s_tmin #{key}", value=0.0, min_value=0.0)
        s_tmax = st.sidebar.number_input(f"s_tmax #{key}", value=999.0, min_value=0.0)
        s_noise = st.sidebar.number_input(f"s_noise #{key}", value=1.0, min_value=0.0)

        if sampler_name == "EulerEDMSampler":
            sampler = EulerEDMSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                s_churn=s_churn,
                s_tmin=s_tmin,
                s_tmax=s_tmax,
                s_noise=s_noise,
                verbose=True,
            )
        elif sampler_name == "HeunEDMSampler":
            sampler = HeunEDMSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                s_churn=s_churn,
                s_tmin=s_tmin,
                s_tmax=s_tmax,
                s_noise=s_noise,
                verbose=True,
            )
    elif (
        sampler_name == "EulerAncestralSampler"
        or sampler_name == "DPMPP2SAncestralSampler"
    ):
        s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
        eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)

        if sampler_name == "EulerAncestralSampler":
            sampler = EulerAncestralSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                eta=eta,
                s_noise=s_noise,
                verbose=True,
            )
        elif sampler_name == "DPMPP2SAncestralSampler":
            sampler = DPMPP2SAncestralSampler(
                num_steps=steps,
                discretization_config=discretization_config,
                guider_config=guider_config,
                eta=eta,
                s_noise=s_noise,
                verbose=True,
            )
    elif sampler_name == "DPMPP2MSampler":
        sampler = DPMPP2MSampler(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            verbose=True,
        )
    elif sampler_name == "LinearMultistepSampler":
        order = st.sidebar.number_input("order", value=4, min_value=1)
        sampler = LinearMultistepSampler(
            num_steps=steps,
            discretization_config=discretization_config,
            guider_config=guider_config,
            order=order,
            verbose=True,
        )
    else:
        raise ValueError(f"unknown sampler {sampler_name}!")

    return sampler


def get_interactive_image(key=None) -> Image.Image:
    image = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if image is not None:
        image = Image.open(image)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image


def load_img(display=True, key=None):
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")

    # Convert to a tensor in [-1, 1] with a leading batch dimension.
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 2.0 - 1.0),
        ]
    )
    img = transform(image)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img


def get_init_img(batch_size=1, key=None):
    init_image = load_img(key=key).cuda()
    init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
    return init_image
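

# Illustrative usage sketch (not part of the original helpers): a Streamlit
# demo built on these utilities might wire them together roughly as below.
# The config/checkpoint paths and the sampling step itself are assumptions;
# the real demo supplies its own model specs and sampling loop.
#
#   version_dict = {
#       "config": "configs/inference/sd_xl_base.yaml",      # assumed path
#       "ckpt": "checkpoints/sd_xl_base_1.0.safetensors",    # assumed path
#   }
#   state = init_st(version_dict, load_ckpt=True, load_filter=True)
#   model = initial_model_load(state["model"])
#
#   keys = get_unique_embedder_keys_from_conditioner(model.conditioner)
#   init_dict = {
#       "orig_width": 1024, "orig_height": 1024,
#       "target_width": 1024, "target_height": 1024,
#   }
#   value_dict = init_embedder_options(keys, init_dict)
#   sampler, num_rows, num_cols = init_sampling()
#
#   # ... run the model's sampling loop with `sampler` and `value_dict`
#   # to obtain a batch of images `samples`, then:
#   outputs = st.empty()
#   show_samples(samples, outputs)
#   save_locally, save_path = init_save_locally(os.path.join("outputs", "demo"))
#   if save_locally and save_path is not None:
#       perform_save_locally(save_path, samples)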