import os
from typing import Optional, Tuple, Dict, Any

import numpy as np
import streamlit as st
import torch
from einops import rearrange, repeat
from PIL import Image
from torchvision import transforms

from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
from sgm.inference.api import (
    Discretization,
    Guider,
    Sampler,
    SamplingParams,
    SamplingSpec,
    SamplingPipeline,
    Thresholder,
)
from sgm.inference.helpers import (
    embed_watermark,
)


@st.cache_resource()
def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, Any]:
    """Build the state dictionary for a sampling session.

    The result is wrapped in ``st.cache_resource``. The module-level
    ``lowvram_mode`` flag is read here to pick fp16 and the device, but it is
    not an argument of this function.
    NOTE(review): presumably the cache is keyed on the arguments only, so
    toggling ``lowvram_mode`` after the first call would not rebuild the
    pipeline -- confirm against the Streamlit caching docs.

    Args:
        spec: Model specification; its ``config`` and ``ckpt`` attributes are
            copied into the returned state.
        load_ckpt: When false, ``state["ckpt"]`` is ``None`` instead of
            ``spec.ckpt``.
        load_filter: When true, a ``DeepFloydDataFiltering`` instance is stored
            under ``state["filter"]``.

    Returns:
        Dict with keys ``spec``, ``model``, ``ckpt``, ``config``, ``params``
        and, optionally, ``filter``.
    """
    # The original guarded this with `if not "model" in state` on a dict that
    # had just been created empty, so the guard was always true; it is dropped.
    pipeline = SamplingPipeline(
        model_spec=spec,
        use_fp16=lowvram_mode,
        device="cpu" if lowvram_mode else "cuda",
    )

    state: Dict[str, Any] = {
        "spec": spec,
        "model": pipeline,
        "ckpt": spec.ckpt if load_ckpt else None,
        "config": spec.config,
        "params": SamplingParams(),
    }
    if load_filter:
        state["filter"] = DeepFloydDataFiltering(verbose=False)
    return state


# Module-wide switch read by init_st(); change it through set_lowvram_mode().
lowvram_mode = False


def set_lowvram_mode(mode):
    """Set the module-level ``lowvram_mode`` flag to ``mode``."""
    global lowvram_mode
    lowvram_mode = mode


def get_unique_embedder_keys_from_conditioner(conditioner):
    """Return the distinct ``input_key`` values of ``conditioner.embedders``.

    The result is a list built from a set, so its order is not defined.
    """
    return list({embedder.input_key for embedder in conditioner.embedders})


def init_embedder_options(
    keys, params: SamplingParams, prompt=None, negative_prompt=None
) -> SamplingParams:
    """Render input widgets for the given conditioning keys and update ``params``.

    Args:
        keys: Iterable of conditioner input keys. ``"txt"``,
            ``"original_size_as_tuple"`` and ``"crop_coords_top_left"`` are
            handled; any other key is ignored.
        params: Sampling parameters, updated in place with the size and crop
            values entered by the user.
        prompt: If ``None``, a "Prompt" text input is rendered for ``"txt"``.
        negative_prompt: If ``None``, a "Negative prompt" text input is
            rendered for ``"txt"``.

    Returns:
        The same ``params`` object.
    """
    for key in keys:
        if key == "txt":
            # NOTE(review): the prompt values collected here are neither stored
            # on `params` nor returned; only the widgets are rendered. Confirm
            # whether callers are expected to read the prompt some other way.
            if prompt is None:
                prompt = st.text_input(
                    "Prompt", "A professional photograph of an astronaut riding a pig"
                )
            if negative_prompt is None:
                negative_prompt = st.text_input("Negative prompt", "")
        elif key == "original_size_as_tuple":
            orig_width = st.number_input(
                "orig_width",
                value=params.orig_width,
                min_value=16,
            )
            orig_height = st.number_input(
                "orig_height",
                value=params.orig_height,
                min_value=16,
            )
            params.orig_width = int(orig_width)
            params.orig_height = int(orig_height)
        elif key == "crop_coords_top_left":
            crop_coord_top = st.number_input(
                "crop_coords_top", value=params.crop_coords_top, min_value=0
            )
            crop_coord_left = st.number_input(
                "crop_coords_left", value=params.crop_coords_left, min_value=0
            )
            params.crop_coords_top = int(crop_coord_top)
            params.crop_coords_left = int(crop_coord_left)
    return params


def perform_save_locally(save_path, samples):
    """Watermark ``samples`` and write each one as a numbered PNG in ``save_path``.

    The directory is created if needed. File names are zero-padded 9-digit
    counters that start at the number of entries already in the directory.
    Each sample is rearranged from "c h w" to "h w c", multiplied by 255 and
    cast to ``uint8`` without clipping.
    """
    os.makedirs(save_path, exist_ok=True)
    first_index = len(os.listdir(save_path))
    samples = embed_watermark(samples)
    for index, sample in enumerate(samples, start=first_index):
        pixels = 255.0 * rearrange(sample.cpu().numpy(), "c h w -> h w c")
        Image.fromarray(pixels.astype(np.uint8)).save(
            os.path.join(save_path, f"{index:09}.png")
        )


def init_save_locally(_dir, init_value: bool = False):
    """Render the "Save images locally" controls.

    Returns:
        ``(save_locally, save_path)``. ``save_path`` is the text-input value
        (defaulting to ``<_dir>/samples``) when the checkbox is ticked, and
        ``None`` otherwise.
    """
    save_locally = st.sidebar.checkbox("Save images locally", value=init_value)
    save_path = (
        st.text_input("Save path", value=os.path.join(_dir, "samples"))
        if save_locally
        else None
    )
    return save_locally, save_path


def show_samples(samples, outputs):
    """Watermark a batch of samples and display it as one row via ``outputs.image``.

    If ``samples`` is a 2-tuple, only its first element is shown.
    """
    if isinstance(samples, tuple):
        samples, _ = samples
    grid = embed_watermark(torch.stack([samples]))
    # Lay the batch out side by side: rows stack along height, batch along width.
    grid = rearrange(grid, "n b c h w -> (n h) (b w) c")
    outputs.image(grid.cpu().numpy())


def get_guider(key, params: SamplingParams) -> SamplingParams:
    """Let the user choose a guider and, for the vanilla guider, its settings.

    Args:
        key: Suffix appended to widget labels so several stages can coexist.
        params: Sampling parameters, updated in place.

    Returns:
        The same ``params`` object.

    Raises:
        NotImplementedError: If a thresholder other than "None" is selected.
    """
    # The label used to read "Discretization", a copy-paste from the
    # discretization selectbox; it now names the value being chosen.
    params.guider = Guider(
        st.sidebar.selectbox(f"Guider #{key}", [member.value for member in Guider])
    )

    if params.guider == Guider.VANILLA:
        params.scale = st.number_input(
            f"cfg-scale #{key}", value=params.scale, min_value=0.0, max_value=100.0
        )
        thresholder = st.sidebar.selectbox(
            f"Thresholder #{key}",
            [
                "None",
            ],
        )
        if thresholder == "None":
            params.thresholder = Thresholder.NONE
        else:
            raise NotImplementedError
    return params


def init_sampling(
    key=1,
    params: Optional[SamplingParams] = None,
    specify_num_samples=True,
) -> Tuple[SamplingParams, int, int]:
    """Render the sampling controls and collect them into fresh parameters.

    Args:
        key: Suffix appended to widget labels.
        params: Only ``img2img_strength`` is carried over from this object;
            every other field starts from the ``SamplingParams`` defaults.
            ``None`` means a default-constructed ``SamplingParams``.
        specify_num_samples: When true, a "num cols" input is rendered.

    Returns:
        ``(params, num_rows, num_cols)``. ``num_rows`` is always 1 and
        ``num_cols`` is 1 unless the user picked another value.
    """
    # A default instance is created per call rather than once at definition
    # time, so no SamplingParams object is shared between calls.
    if params is None:
        params = SamplingParams()
    params = SamplingParams(img2img_strength=params.img2img_strength)

    num_rows, num_cols = 1, 1
    if specify_num_samples:
        num_cols = st.number_input(
            f"num cols #{key}", value=2, min_value=1, max_value=10
        )

    params.steps = int(
        st.sidebar.number_input(
            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
        )
    )
    params.sampler = Sampler(
        st.sidebar.selectbox(
            f"Sampler #{key}",
            [member.value for member in Sampler],
            0,
        )
    )
    params.discretization = Discretization(
        st.sidebar.selectbox(
            f"Discretization #{key}",
            [member.value for member in Discretization],
        )
    )

    params = get_discretization(params, key=key)
    params = get_guider(key=key, params=params)
    params = get_sampler(params, key=key)
    return params, num_rows, num_cols


def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
    """For EDM discretization, ask for ``sigma_min``, ``sigma_max`` and ``rho``.

    ``params`` is updated in place and returned; other discretizations leave
    it untouched.
    """
    if params.discretization == Discretization.EDM:
        params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
        params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
        params.rho = st.number_input(f"rho #{key}", value=params.rho)
    return params


def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
    """Render the sidebar inputs specific to the selected sampler.

    ``params`` is updated in place and returned. Samplers not listed below
    add no inputs.
    """
    if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
        params.s_churn = st.sidebar.number_input(
            f"s_churn #{key}", value=params.s_churn, min_value=0.0
        )
        params.s_tmin = st.sidebar.number_input(
            f"s_tmin #{key}", value=params.s_tmin, min_value=0.0
        )
        params.s_tmax = st.sidebar.number_input(
            f"s_tmax #{key}", value=params.s_tmax, min_value=0.0
        )
        params.s_noise = st.sidebar.number_input(
            f"s_noise #{key}", value=params.s_noise, min_value=0.0
        )
    elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
        # These labels now carry the `#{key}` suffix like every other
        # per-stage widget, so two stages using the same sampler do not
        # produce identical widgets.
        params.s_noise = st.sidebar.number_input(
            f"s_noise #{key}", value=params.s_noise, min_value=0.0
        )
        params.eta = st.sidebar.number_input(
            f"eta #{key}", value=params.eta, min_value=0.0
        )
    elif params.sampler == Sampler.LINEAR_MULTISTEP:
        params.order = int(
            st.sidebar.number_input(f"order #{key}", value=params.order, min_value=1)
        )
    return params


def get_interactive_image(key=None) -> Optional[Image.Image]:
    """Show a file uploader and return the uploaded image converted to RGB.

    Returns ``None`` while no file has been uploaded.
    """
    uploaded = st.file_uploader("Input", type=["jpg", "JPEG", "png"], key=key)
    if uploaded is None:
        return None
    image = Image.open(uploaded)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image


def load_img(display=True, key=None) -> Optional[torch.Tensor]:
    """Load the uploaded image as a tensor with a leading batch axis of 1.

    The image goes through ``ToTensor`` and is then mapped with
    ``x * 2.0 - 1.0``. Min/max/mean of the result are written to the page.

    Args:
        display: When true, the uploaded image is shown with ``st.image``.
        key: Widget key forwarded to the file uploader.

    Returns:
        The image tensor, or ``None`` if nothing has been uploaded.
    """
    image = get_interactive_image(key=key)
    if image is None:
        return None
    if display:
        st.image(image)
    w, h = image.size
    print(f"loaded input image of size ({w}, {h})")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 2.0 - 1.0),
        ]
    )
    img = transform(image)[None, ...]
    st.text(f"input min/max/mean: {img.min():.3f}/{img.max():.3f}/{img.mean():.3f}")
    return img


def get_init_img(batch_size=1, key=None):
    """Load the uploaded image onto CUDA and repeat it ``batch_size`` times.

    Returns:
        The repeated image tensor, or ``None`` when no image has been uploaded.
        Previously that case raised ``AttributeError`` from calling ``.cuda()``
        on ``None``.
    """
    init_image = load_img(key=key)
    if init_image is None:
        return None
    return repeat(init_image.cuda(), "1 ... -> b ...", b=batch_size)