mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-01-06 15:14:23 +01:00
SDXL-Turbo
This commit is contained in:
214
scripts/demo/turbo.py
Normal file
214
scripts/demo/turbo.py
Normal file
@@ -0,0 +1,214 @@
|
||||
from streamlit_helpers import *
|
||||
from st_keyup import st_keyup
|
||||
from sgm.modules.diffusionmodules.sampling import EulerAncestralSampler
|
||||
|
||||
# Registry of loadable checkpoints. The keys are offered in the
# "Model Version" select box; each value is the spec dict handed to init_st().
VERSION2SPECS = {
    "SDXL-Turbo": dict(
        H=512,
        W=512,
        C=4,
        f=8,
        is_legacy=False,
        config="configs/inference/sd_xl_base.yaml",
        ckpt="checkpoints/sd_xl_turbo_1.0.safetensors",
    ),
}
|
||||
|
||||
|
||||
class SubstepSampler(EulerAncestralSampler):
    """Euler-ancestral sampler that only visits a few fixed schedule indices.

    The full sigma schedule is produced by the inherited discretization and
    then sub-sampled: the first ``n_sample_steps`` entries of ``steps_subset``
    are used as starting points and the last entry of ``steps_subset`` is
    always the terminal index.

    Args:
        n_sample_steps: Number of leading ``steps_subset`` indices to use.
        *args, **kwargs: Forwarded unchanged to ``EulerAncestralSampler``.
    """

    def __init__(self, n_sample_steps=1, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.n_sample_steps = n_sample_steps
        # Indices into the discretized sigma tensor. Index 1000 has to exist,
        # so the schedule needs more than 1000 entries.
        # NOTE(review): the demo builds the sampler with num_steps=1000;
        # confirm the discretizer appends one extra final sigma.
        self.steps_subset = [0, 100, 200, 300, 1000]

    def prepare_sampling_loop(self, x, cond, uc=None, num_steps=None):
        """Set up the sub-sampled sigma schedule and scale the initial latent.

        Args:
            x: Initial noise tensor; it is scaled IN PLACE.
            cond: Conditioning passed through to the sampling loop.
            uc: Ignored; the returned unconditional conditioning is ``cond``.
            num_steps: Optional override for ``self.num_steps``.

        Returns:
            Tuple ``(x, s_in, sigmas, num_sigmas, cond, uc)`` where ``s_in``
            is a ones vector with one entry per batch element of ``x``.
        """
        total_steps = self.num_steps if num_steps is None else num_steps
        sigmas = self.discretization(total_steps, device=self.device)

        # Take the requested leading indices and terminate with the final one.
        # If the slice already ends with the terminal index (n_sample_steps at
        # least len(steps_subset), or None) do not append it a second time:
        # a duplicated final sigma would add a zero-length last step.
        picked = self.steps_subset[: self.n_sample_steps]
        terminal = self.steps_subset[-1:]
        if picked[-1:] != terminal:
            picked = picked + terminal
        sigmas = sigmas[picked]

        # The caller-supplied uc is deliberately discarded in favour of cond.
        # NOTE(review): presumably because Turbo samples without
        # classifier-free guidance; confirm.
        uc = cond

        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc
|
||||
|
||||
|
||||
def seeded_randn(shape, seed, device="cuda"):
    """Return float32 standard-normal noise that is reproducible per seed.

    The values are drawn with NumPy's ``RandomState`` so the same
    ``(shape, seed)`` pair always yields the same numbers.

    Args:
        shape: Iterable of ints giving the output shape.
        seed: Seed for ``np.random.RandomState``. Python ints outside
            ``[0, 2**32 - 1]`` are wrapped modulo ``2**32``; any other value
            (e.g. ``None``) is handed to ``RandomState`` unchanged.
        device: Torch device for the result. Defaults to ``"cuda"``, the
            value that used to be hard-coded.

    Returns:
        A ``torch.float32`` tensor of the given shape on ``device``.
    """
    # RandomState rejects integer seeds above 2**32 - 1 (and negative ones)
    # with ValueError; 64-bit seeds such as torch.seed() would hit that.
    if isinstance(seed, int):
        seed %= 2**32
    noise = np.random.RandomState(seed).randn(*shape)
    return torch.from_numpy(noise).to(device=device, dtype=torch.float32)
|
||||
|
||||
|
||||
class SeededNoise:
    """Callable noise source whose seed advances by one on every call."""

    def __init__(self, seed):
        # Seed of the most recent draw; the first call uses ``seed + 1``.
        self.seed = seed

    def __call__(self, x):
        """Return seeded noise with the same shape as ``x``."""
        next_seed = self.seed + 1
        self.seed = next_seed
        return seeded_randn(x.shape, next_seed)
|
||||
|
||||
|
||||
def init_embedder_options(keys, init_dict, prompt=None, negative_prompt=None):
    """Build the conditioning value dict for the requested embedder keys.

    Args:
        keys: Iterable of embedder input keys. Recognised keys are ``"txt"``,
            ``"original_size_as_tuple"``, ``"crop_coords_top_left"``,
            ``"aesthetic_score"`` and ``"target_size_as_tuple"``; any other
            key is ignored.
        init_dict: Mapping supplying ``orig_width`` / ``orig_height`` (for
            ``"original_size_as_tuple"``) and ``target_width`` /
            ``target_height`` (for ``"target_size_as_tuple"``).
        prompt: Text prompt stored under ``"prompt"``.
        negative_prompt: Text stored under ``"negative_prompt"``; ``None``
            (the default) yields the empty string.

    Returns:
        Dict holding only the entries required by ``keys``.

    Raises:
        KeyError: If a size key is requested but ``init_dict`` lacks the
            corresponding width/height entry.
    """
    value_dict = {}
    for key in keys:
        if key == "txt":
            value_dict["prompt"] = prompt
            # The argument used to be ignored in favour of a hard-coded "".
            value_dict["negative_prompt"] = (
                "" if negative_prompt is None else negative_prompt
            )
        elif key == "original_size_as_tuple":
            value_dict["orig_width"] = init_dict["orig_width"]
            value_dict["orig_height"] = init_dict["orig_height"]
        elif key == "crop_coords_top_left":
            # The demo never crops, so both offsets are fixed at zero.
            value_dict["crop_coords_top"] = 0
            value_dict["crop_coords_left"] = 0
        elif key == "aesthetic_score":
            value_dict["aesthetic_score"] = 6.0
            value_dict["negative_aesthetic_score"] = 2.5
        elif key == "target_size_as_tuple":
            value_dict["target_width"] = init_dict["target_width"]
            value_dict["target_height"] = init_dict["target_height"]

    return value_dict
|
||||
|
||||
|
||||
def sample(
    model,
    sampler,
    prompt="A lush garden with oversized flowers and vibrant colors, inhabited by miniature animals.",
    H=1024,
    W=1024,
    seed=0,
    filter=None,
):
    """Generate one image for ``prompt`` and return it as a uint8 array.

    Args:
        model: Object providing ``conditioner``, ``denoiser``, ``model`` and
            ``decode_first_stage``.
        sampler: Callable invoked as ``sampler(denoiser, noise, cond=, uc=)``.
        prompt: Text prompt used for conditioning.
        H: Image height in pixels; the latent height is ``H // 8``.
        W: Image width in pixels; the latent width is ``W // 8``.
        seed: Seed for the initial latent noise. ``None`` draws a fresh seed
            from ``torch.seed()``.
        filter: Optional callable applied to the clamped ``[0, 1]`` image
            tensor before conversion to uint8.

    Returns:
        NumPy ``uint8`` array obtained by permuting the decoded batch with
        ``(0, 2, 3, 1)`` (presumably batch, height, width, channel).
    """
    F = 8  # pixel-to-latent downsampling factor
    C = 4  # latent channels
    shape = (1, C, H // F, W // F)

    # Resolve the embedder keys once; they feed both the value dict and the batch.
    embedder_keys = list(get_unique_embedder_keys_from_conditioner(model.conditioner))
    value_dict = init_embedder_options(
        keys=embedder_keys,
        init_dict={
            "orig_width": W,
            "orig_height": H,
            "target_width": W,
            "target_height": H,
        },
        prompt=prompt,
    )

    if seed is None:
        # torch.seed() returns a 64-bit number, but the seed ends up in
        # np.random.RandomState (via seeded_randn), which only accepts values
        # up to 2**32 - 1 and raises ValueError otherwise.
        seed = torch.seed() % 2**32

    with torch.no_grad(), autocast("cuda"):
        batch, _ = get_batch(embedder_keys, value_dict, [1])
        c = model.conditioner(batch)
        uc = None
        randn = seeded_randn(shape, seed)

        def denoiser(input, sigma, c):
            # Bind the network so the sampler only supplies (input, sigma, c).
            return model.denoiser(model.model, input, sigma, c)

        samples_z = sampler(denoiser, randn, cond=c, uc=uc)
        samples_x = model.decode_first_stage(samples_z)
        # Map the decoder output from [-1, 1] to [0, 1].
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
        if filter is not None:
            samples = filter(samples)
        samples = (
            (255 * samples)
            .to(dtype=torch.uint8)
            .permute(0, 2, 3, 1)
            .detach()
            .cpu()
            .numpy()
        )
    return samples
|
||||
|
||||
|
||||
def v_spacer(height) -> None:
    """Write ``height`` newline entries to the page as vertical padding."""
    for _row in range(height):
        st.write("\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    st.title("Turbo")

    # Header row: model picker | load toggle | step count.
    head_cols = st.columns([1, 1, 1])
    with head_cols[0]:
        version = st.selectbox("Model Version", list(VERSION2SPECS.keys()), 0)
        version_dict = VERSION2SPECS[version]

    with head_cols[1]:
        v_spacer(2)
        # Nothing is loaded or sampled until the user ticks the box.
        mode = "txt2img" if st.checkbox("Load Model") else "skip"

    if mode != "skip":
        state = init_st(version_dict, load_filter=True)
        if state["msg"]:
            st.info(state["msg"])
        model = state["model"]
        load_model(model)

    # The seed is kept in the session state so it survives Streamlit reruns.
    if "seed" not in st.session_state:
        st.session_state.seed = 0

    def increment_counter():
        st.session_state.seed += 1

    def decrement_counter():
        # Never step below zero.
        if st.session_state.seed > 0:
            st.session_state.seed -= 1

    with head_cols[2]:
        n_steps = st.number_input(label="number of steps", min_value=1, max_value=4)

    sampler = SubstepSampler(
        n_sample_steps=n_steps,
        num_steps=1000,
        eta=1.0,
        discretization_config=dict(
            target="sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization"
        ),
    )

    default_prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    prompt = st_keyup("Enter a value", value=default_prompt, debounce=300, key="interactive_text")

    # Body row: previous-seed button | image | next-seed button.
    cols = st.columns([1, 5, 1])
    if mode != "skip":
        with cols[0]:
            v_spacer(14)
            st.button("↩", on_click=decrement_counter)
        with cols[2]:
            v_spacer(14)
            st.button("↪", on_click=increment_counter)

        current_seed = st.session_state.seed
        sampler.noise_sampler = SeededNoise(seed=current_seed)
        out = sample(
            model,
            sampler,
            H=512,
            W=512,
            seed=current_seed,
            prompt=prompt,
            filter=state.get("filter"),
        )
        with cols[1]:
            st.image(out[0])
|
||||
Reference in New Issue
Block a user