mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
PR fixes, model specific defaults
This commit is contained in:
@@ -66,9 +66,7 @@ def load_img(display=True, key=None, device="cuda"):
|
||||
st.image(image)
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
width, height = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
image = image.resize((width, height))
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
@@ -78,26 +76,19 @@ def load_img(display=True, key=None, device="cuda"):
|
||||
|
||||
def run_txt2img(
|
||||
state,
|
||||
version: str,
|
||||
model_id: ModelArchitecture,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
return_latents=False,
|
||||
stage2strength=None,
|
||||
):
|
||||
spec: SamplingSpec = state["spec"]
|
||||
model: SamplingPipeline = state["model"]
|
||||
params: SamplingParams = state["params"]
|
||||
if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
|
||||
width, height = st.selectbox(
|
||||
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
|
||||
)
|
||||
if model_id in sdxl_base_model_list:
|
||||
width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
|
||||
else:
|
||||
height = int(
|
||||
st.number_input("H", value=spec.height, min_value=64, max_value=2048)
|
||||
)
|
||||
width = int(
|
||||
st.number_input("W", value=spec.width, min_value=64, max_value=2048)
|
||||
)
|
||||
height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048))
|
||||
width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048))
|
||||
|
||||
params = init_embedder_options(
|
||||
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
|
||||
@@ -207,12 +198,12 @@ def apply_refiner(
|
||||
|
||||
|
||||
sdxl_base_model_list = [
|
||||
ModelArchitecture.SDXL_V1_BASE,
|
||||
ModelArchitecture.SDXL_V1_0_BASE,
|
||||
ModelArchitecture.SDXL_V0_9_BASE,
|
||||
]
|
||||
|
||||
sdxl_refiner_model_list = [
|
||||
ModelArchitecture.SDXL_V1_REFINER,
|
||||
ModelArchitecture.SDXL_V1_0_REFINER,
|
||||
ModelArchitecture.SDXL_V0_9_REFINER,
|
||||
]
|
||||
|
||||
@@ -239,9 +230,7 @@ if __name__ == "__main__":
|
||||
else:
|
||||
add_pipeline = False
|
||||
|
||||
seed = int(
|
||||
st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
|
||||
)
|
||||
seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)))
|
||||
seed_everything(seed)
|
||||
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
|
||||
@@ -281,9 +270,7 @@ if __name__ == "__main__":
|
||||
st.write("**Refiner Options:**")
|
||||
|
||||
specs2 = model_specs[version2]
|
||||
state2 = init_st(
|
||||
specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
|
||||
)
|
||||
state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap)
|
||||
params2 = state2["params"]
|
||||
|
||||
params2.img2img_strength = st.number_input(
|
||||
@@ -309,7 +296,7 @@ if __name__ == "__main__":
|
||||
if mode == "txt2img":
|
||||
out = run_txt2img(
|
||||
state=state,
|
||||
version=str(version),
|
||||
model_id=version_enum,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
return_latents=add_pipeline,
|
||||
|
||||
@@ -48,7 +48,7 @@ def init_st(
|
||||
state["model"] = pipeline
|
||||
state["ckpt"] = ckpt if load_ckpt else None
|
||||
state["config"] = config
|
||||
state["params"] = SamplingParams()
|
||||
state["params"] = spec.default_params
|
||||
if load_filter:
|
||||
state["filter"] = DeepFloydDataFiltering(verbose=False)
|
||||
else:
|
||||
@@ -132,9 +132,7 @@ def show_samples(samples, outputs):
|
||||
|
||||
def get_guider(params: SamplingParams, key=1) -> SamplingParams:
|
||||
params.guider = Guider(
|
||||
st.sidebar.selectbox(
|
||||
f"Discretization #{key}", [member.value for member in Guider]
|
||||
)
|
||||
st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
|
||||
)
|
||||
|
||||
if params.guider == Guider.VANILLA:
|
||||
@@ -165,14 +163,10 @@ def init_sampling(
|
||||
|
||||
num_rows, num_cols = 1, 1
|
||||
if specify_num_samples:
|
||||
num_cols = st.number_input(
|
||||
f"num cols #{key}", value=2, min_value=1, max_value=10
|
||||
)
|
||||
num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
|
||||
|
||||
params.steps = int(
|
||||
st.sidebar.number_input(
|
||||
f"steps #{key}", value=params.steps, min_value=1, max_value=1000
|
||||
)
|
||||
st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
|
||||
)
|
||||
|
||||
params.sampler = Sampler(
|
||||
@@ -220,15 +214,11 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
|
||||
)
|
||||
|
||||
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
|
||||
params.s_noise = st.sidebar.number_input(
|
||||
"s_noise", value=params.s_noise, min_value=0.0
|
||||
)
|
||||
params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
|
||||
params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
|
||||
|
||||
elif params.sampler == Sampler.LINEAR_MULTISTEP:
|
||||
params.order = int(
|
||||
st.sidebar.number_input("order", value=params.order, min_value=1)
|
||||
)
|
||||
params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
|
||||
return params
|
||||
|
||||
|
||||
|
||||
@@ -24,8 +24,8 @@ from typing import Optional, Dict, Any, Union
|
||||
|
||||
|
||||
class ModelArchitecture(str, Enum):
|
||||
SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
|
||||
SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
|
||||
SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base"
|
||||
SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner"
|
||||
SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
|
||||
SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
|
||||
SD_2_1 = "stable-diffusion-v2-1"
|
||||
@@ -59,24 +59,21 @@ class Thresholder(str, Enum):
|
||||
class SamplingParams:
|
||||
"""
|
||||
Parameters for sampling.
|
||||
The defaults here are derived from user preference testing.
|
||||
They will be subject to change in the future, likely pulled
|
||||
from model specs instead of global defaults.
|
||||
"""
|
||||
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
steps: int = 40
|
||||
width: int
|
||||
height: int
|
||||
steps: int
|
||||
sampler: Sampler = Sampler.EULER_EDM
|
||||
discretization: Discretization = Discretization.LEGACY_DDPM
|
||||
guider: Guider = Guider.VANILLA
|
||||
thresholder: Thresholder = Thresholder.NONE
|
||||
scale: float = 5.0
|
||||
scale: float
|
||||
aesthetic_score: float = 6.0
|
||||
negative_aesthetic_score: float = 2.5
|
||||
img2img_strength: float = 1.0
|
||||
orig_width: int = width
|
||||
orig_height: int = height
|
||||
orig_width: int = 1024
|
||||
orig_height: int = 1024
|
||||
crop_coords_top: int = 0
|
||||
crop_coords_left: int = 0
|
||||
sigma_min: float = 0.0292
|
||||
@@ -100,8 +97,10 @@ class SamplingSpec:
|
||||
config: str
|
||||
ckpt: str
|
||||
is_guided: bool
|
||||
default_params: SamplingParams
|
||||
|
||||
|
||||
# The defaults here are derived from user preference testing.
|
||||
model_specs = {
|
||||
ModelArchitecture.SD_2_1: SamplingSpec(
|
||||
height=512,
|
||||
@@ -112,6 +111,12 @@ model_specs = {
|
||||
config="sd_2_1.yaml",
|
||||
ckpt="v2-1_512-ema-pruned.safetensors",
|
||||
is_guided=True,
|
||||
default_params=SamplingParams(
|
||||
width=512,
|
||||
height=512,
|
||||
steps=40,
|
||||
scale=7.0,
|
||||
),
|
||||
),
|
||||
ModelArchitecture.SD_2_1_768: SamplingSpec(
|
||||
height=768,
|
||||
@@ -122,6 +127,12 @@ model_specs = {
|
||||
config="sd_2_1_768.yaml",
|
||||
ckpt="v2-1_768-ema-pruned.safetensors",
|
||||
is_guided=True,
|
||||
default_params=SamplingParams(
|
||||
width=768,
|
||||
height=768,
|
||||
steps=40,
|
||||
scale=7.0,
|
||||
),
|
||||
),
|
||||
ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
|
||||
height=1024,
|
||||
@@ -132,6 +143,7 @@ model_specs = {
|
||||
config="sd_xl_base.yaml",
|
||||
ckpt="sd_xl_base_0.9.safetensors",
|
||||
is_guided=True,
|
||||
default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
|
||||
),
|
||||
ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
|
||||
height=1024,
|
||||
@@ -142,8 +154,11 @@ model_specs = {
|
||||
config="sd_xl_refiner.yaml",
|
||||
ckpt="sd_xl_refiner_0.9.safetensors",
|
||||
is_guided=True,
|
||||
default_params=SamplingParams(
|
||||
width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
|
||||
),
|
||||
),
|
||||
ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
|
||||
ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec(
|
||||
height=1024,
|
||||
width=1024,
|
||||
channels=4,
|
||||
@@ -152,8 +167,9 @@ model_specs = {
|
||||
config="sd_xl_base.yaml",
|
||||
ckpt="sd_xl_base_1.0.safetensors",
|
||||
is_guided=True,
|
||||
default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
|
||||
),
|
||||
ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
|
||||
ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec(
|
||||
height=1024,
|
||||
width=1024,
|
||||
channels=4,
|
||||
@@ -162,10 +178,39 @@ model_specs = {
|
||||
config="sd_xl_refiner.yaml",
|
||||
ckpt="sd_xl_refiner_1.0.safetensors",
|
||||
is_guided=True,
|
||||
default_params=SamplingParams(
|
||||
width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def wrap_discretization(
|
||||
discretization, image_strength=None, noise_strength=None, steps=None
|
||||
):
|
||||
if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
|
||||
discretization, Txt2NoisyDiscretizationWrapper
|
||||
):
|
||||
return discretization # Already wrapped
|
||||
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
|
||||
discretization = Img2ImgDiscretizationWrapper(
|
||||
discretization, strength=image_strength
|
||||
)
|
||||
|
||||
if (
|
||||
noise_strength is not None
|
||||
and noise_strength < 1.0
|
||||
and noise_strength > 0.0
|
||||
and steps is not None
|
||||
):
|
||||
discretization = Txt2NoisyDiscretizationWrapper(
|
||||
discretization,
|
||||
strength=noise_strength,
|
||||
original_steps=steps,
|
||||
)
|
||||
return discretization
|
||||
|
||||
|
||||
class SamplingPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -231,17 +276,19 @@ class SamplingPipeline:
|
||||
|
||||
def text_to_image(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
prompt: str,
|
||||
params: Optional[SamplingParams] = None,
|
||||
negative_prompt: str = "",
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
noise_strength: Optional[float] = None,
|
||||
filter=None,
|
||||
):
|
||||
if params is None:
|
||||
params = self.specs.default_params
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
sampler.discretization = self.wrap_discretization(
|
||||
sampler.discretization = wrap_discretization(
|
||||
sampler.discretization,
|
||||
image_strength=None,
|
||||
noise_strength=noise_strength,
|
||||
@@ -270,18 +317,20 @@ class SamplingPipeline:
|
||||
|
||||
def image_to_image(
|
||||
self,
|
||||
params: SamplingParams,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
params: Optional[SamplingParams] = None,
|
||||
negative_prompt: str = "",
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
noise_strength: Optional[float] = None,
|
||||
filter=None,
|
||||
):
|
||||
if params is None:
|
||||
params = self.specs.default_params
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
sampler.discretization = self.wrap_discretization(
|
||||
sampler.discretization = wrap_discretization(
|
||||
sampler.discretization,
|
||||
image_strength=params.img2img_strength,
|
||||
noise_strength=noise_strength,
|
||||
@@ -308,44 +357,20 @@ class SamplingPipeline:
|
||||
device=self.device_manager,
|
||||
)
|
||||
|
||||
def wrap_discretization(
|
||||
self, discretization, image_strength=None, noise_strength=None, steps=None
|
||||
):
|
||||
if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
|
||||
discretization, Txt2NoisyDiscretizationWrapper
|
||||
):
|
||||
return discretization # Already wrapped
|
||||
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
|
||||
discretization = Img2ImgDiscretizationWrapper(
|
||||
discretization, strength=image_strength
|
||||
)
|
||||
|
||||
if (
|
||||
noise_strength is not None
|
||||
and noise_strength < 1.0
|
||||
and noise_strength > 0.0
|
||||
and steps is not None
|
||||
):
|
||||
discretization = Txt2NoisyDiscretizationWrapper(
|
||||
discretization,
|
||||
strength=noise_strength,
|
||||
original_steps=steps,
|
||||
)
|
||||
return discretization
|
||||
|
||||
def refiner(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
params: SamplingParams = SamplingParams(
|
||||
sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.15
|
||||
),
|
||||
params: Optional[SamplingParams] = None,
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
filter: Any = None,
|
||||
add_noise: bool = False,
|
||||
):
|
||||
if params is None:
|
||||
params = self.specs.default_params
|
||||
|
||||
sampler = get_sampler_config(params)
|
||||
value_dict = {
|
||||
"orig_width": image.shape[3] * 8,
|
||||
|
||||
@@ -27,7 +27,7 @@ class TestInference:
|
||||
@fixture(
|
||||
scope="class",
|
||||
params=[
|
||||
[ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
|
||||
[ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
|
||||
[ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
|
||||
],
|
||||
ids=["SDXL_V1", "SDXL_V0_9"],
|
||||
|
||||
Reference in New Issue
Block a user