# Mirror of https://github.com/Stability-AI/generative-models.git
# (synced 2025-12-18 13:54:20 +01:00; 226 lines, 6.3 KiB, Python).
from st_keyup import st_keyup
|
|
from streamlit_helpers import *
|
|
|
|
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler
|
|
|
|
# Settings for the single checkpoint this demo offers. The whole mapping is
# handed to `init_st` by the script at the bottom of the file.
# NOTE(review): H/W/C/f look like image height/width, latent channels and
# down-sampling factor -- confirm against `init_st`, which is not defined here.
_SDXL_TURBO_SPEC = dict(
    H=512,
    W=512,
    C=4,
    f=8,
    is_legacy=False,
    config="configs/inference/sd_xl_base.yaml",
    ckpt="checkpoints/sd_xl_turbo_1.0.safetensors",
)

# Registry keyed by the label shown in the "Model Version" select box.
VERSION2SPECS = {"SDXL-Turbo": _SDXL_TURBO_SPEC}
|
|
|
|
|
|
class SubstepSampler(EulerAncestralSampler):
    """Euler-ancestral sampler that only visits a few fixed schedule indices.

    Instead of walking the whole sigma schedule, the sampling loop is
    prepared with the first ``n_sample_steps`` entries of ``steps_subset``
    followed by its last entry.

    Args:
        n_sample_steps: Number of leading ``steps_subset`` entries to use.
        *args, **kwargs: Forwarded unchanged to ``EulerAncestralSampler``.
    """

    def __init__(self, n_sample_steps=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        # Indices into the sigma schedule returned by ``self.discretization``.
        # The last entry is always appended as the closing sigma; the entries
        # before it are the candidate step starting points.
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Build the reduced sigma schedule and scale the initial noise.

        Args:
            x: Initial noise tensor. It is scaled **in place**.
            cond: Conditioning passed through to the sampling loop.
            uc: Accepted for interface compatibility but ignored; ``cond``
                is returned in its place.
            num_steps: Length requested from the discretization; falls back
                to ``self.num_steps`` when ``None``.

        Returns:
            Tuple ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where ``s_in``
            is a vector of ones with one entry per batch element of ``x``.
        """
        total_steps = self.num_steps if num_steps is None else num_steps
        sigmas = self.discretization(total_steps, device=self.device)

        # Cap the step count at the number of available starting indices.
        # Without the cap, a larger value would select the final index twice
        # and produce a duplicated last sigma (a zero-length final step).
        max_sample_steps = len(self.steps_subset) - 1
        n_sample_steps = min(self.n_sample_steps, max_sample_steps)
        sigmas = sigmas[self.steps_subset[:n_sample_steps] + self.steps_subset[-1:]]

        # The unconditional branch is replaced by the conditional one.
        uc = cond

        # In-place scaling: the caller's tensor is modified as well.
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc
|
|
|
|
|
|
def seeded_randn(shape, seed, device="cuda"):
    """Return reproducible standard-normal noise as a float32 tensor.

    The values are drawn with NumPy's ``RandomState`` so that a given
    ``(shape, seed)`` pair always yields the same numbers.

    Args:
        shape: Iterable of ints giving the output shape.
        seed: Seed for the generator. Integer seeds outside the range
            accepted by ``RandomState`` (``0 .. 2**32 - 1``) are wrapped into
            it; other values (e.g. ``None``) are passed through unchanged.
        device: Device the tensor is moved to. Defaults to ``"cuda"``.

    Returns:
        ``torch.Tensor`` of dtype ``torch.float32`` with the given shape.
    """
    # RandomState raises ValueError for integer seeds outside 32 bits, e.g.
    # the 64-bit values produced by ``torch.seed()``. In-range seeds are
    # unchanged by the modulo.
    if isinstance(seed, (int, np.integer)):
        seed = seed % (2**32)
    noise = np.random.RandomState(seed).randn(*shape)
    return torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
|
|
|
|
|
class SeededNoise:
    """Callable noise source that derives each draw from a running seed."""

    def __init__(self, seed):
        # Starting seed; it is advanced by one before every draw.
        self.seed = seed

    def __call__(self, x):
        """Advance the seed and return seeded noise with the shape of ``x``."""
        next_seed = self.seed + 1
        self.seed = next_seed
        return seeded_randn(x.shape, next_seed)
|
|
|
|
|
|
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Build the value dictionary for the requested embedder keys.

    Args:
        keys: Iterable of embedder input keys. Unrecognised keys are ignored.
        init_dict: Mapping supplying ``orig_width`` / ``orig_height`` (for
            ``"original_size_as_tuple"``) and ``target_width`` /
            ``target_height`` (for ``"target_size_as_tuple"``).
        prompt: Text stored under ``"prompt"`` when ``"txt"`` is requested.
        negative_prompt: Text stored under ``"negative_prompt"`` when
            ``"txt"`` is requested; ``None`` is stored as ``""``.

    Returns:
        dict with the entries for every recognised key in ``keys``.

    Raises:
        KeyError: if ``init_dict`` lacks an entry needed by a requested key.
    """
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            # Honour the caller's negative prompt instead of discarding it;
            # an omitted one still maps to the empty string.
            value_dict["negative_prompt"] = (
                "" if negative_prompt is None else negative_prompt
            )
        elif key == "original_size_as_tuple":
            orig_width = init_dict["orig_width"]
            orig_height = init_dict["orig_height"]
            value_dict["orig_width"] = orig_width
            value_dict["orig_height"] = orig_height
        elif key == "crop_coords_top_left":
            # No cropping: both coordinates are fixed at zero.
            value_dict["crop_coords_top"] = 0
            value_dict["crop_coords_left"] = 0
        elif key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5
        elif key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict
|
|
|
|
|
|
def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    """Generate images for ``prompt`` and return them as a uint8 NumPy array.

    Args:
        model: Object providing ``conditioner``, ``denoiser``, ``model`` and
            ``decode_first_stage``.
        sampler: Callable invoked as ``sampler(denoiser, noise, cond=, uc=)``.
        prompt: Text prompt placed into the conditioning batch.
        H: Image height; the noise uses ``H // 8`` rows.
        W: Image width; the noise uses ``W // 8`` columns.
        seed: Seed for the initial noise. ``None`` draws a fresh seed from
            ``torch.seed()``.
        filter: Optional callable applied to the clamped ``[0, 1]`` samples
            before conversion to uint8.

    Returns:
        NumPy ``uint8`` array holding the decoded samples after
        ``permute(0, 2, 3, 1)``.
        NOTE(review): presumably (batch, height, width, channels) -- confirm
        against ``decode_first_stage``.
    """
    F = 8
    C = 4
    shape = (1, C, H // F, W // F)

    value_dict = init_embedder_options(
        keys=get_unique_embedder_keys_from_conditioner(model.conditioner),
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        # torch.seed() returns a 64-bit value, but the noise is drawn with a
        # NumPy RandomState, which only accepts seeds below 2**32.
        seed = torch.seed() % (2**32)

    precision_scope = autocast
    with torch.no_grad():
        with precision_scope("cuda"):
            # The second element of the returned pair is not used here.
            batch, _batch_uc = get_batch(
                get_unique_embedder_keys_from_conditioner(model.conditioner),
                value_dict,
                [1],
            )
            c = model.conditioner(batch)
            # No separate unconditional conditioning is supplied.
            uc = None
            randn = seeded_randn(shape, seed)

            def denoiser(input, sigma, c):
                # Bind the network so the sampler only supplies (x, sigma, cond).
                return model.denoiser(
                    model.model,
                    input,
                    sigma,
                    c,
                )

            samples_z = sampler(denoiser, randn, cond=c, uc=uc)
            samples_x = model.decode_first_stage(samples_z)
            # Map decoder output from [-1, 1] to [0, 1], clipping overshoot.
            samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
            if filter is not None:
                samples = filter(samples)
            samples = (
                (255 * samples)
                .to(dtype=torch.uint8)
                .permute(0, 2, 3, 1)
                .detach()
                .cpu()
                .numpy()
            )
    return samples
|
|
|
|
|
|
def v_spacer(height) -> None:
    """Emit ``height`` newline writes to the Streamlit page as vertical padding."""
    blank_rows = range(height)
    for _row in blank_rows:
        st.write("\n")
|
|
|
|
|
|
if __name__ == "__main__":
    st.title("Turbo")

    # Header row: model version | load toggle | number of sampling steps.
    header_left, header_mid, header_right = st.columns([1, 1, 1])
    with header_left:
        version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
        version_dict = VERSION2SPECS[version]

    with header_mid:
        v_spacer(2)
        # Nothing is loaded or sampled until the box is ticked.
        mode = "txt2img" if st.checkbox("Load Model") else "skip"

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]
        load_model(model)

    # The seed lives in the session state and is changed by the two
    # arrow buttons below.
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        # Never step below zero.
        if st.session_state.seed <= 0:
            return
        st.session_state.seed -= 1

    with header_right:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)

    sampler = SubstepSampler(
        n_sample_steps=n_steps,
        num_steps=1000,
        eta=1.0,
        discretization_config={
            "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        },
    )

    default_prompt = (
        "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    )
    prompt = st_keyup(
        "Enter a value",
        value=default_prompt,
        debounce=300,
        key="interactive_text",
    )

    # Body row: previous-seed button | image | next-seed button.
    body_left, body_mid, body_right = st.columns([1, 5, 1])
    if mode != "skip":
        with body_left:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with body_right:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)

        current_seed = st.session_state.seed
        sampler.noise_sampler = SeededNoise(seed=current_seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=current_seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with body_mid:
            # Show the first image of the returned array.
            st.image(out[0])
|