PR fixes, model specific defaults

Stephan Auerhahn
2023-08-12 05:33:16 -07:00
parent c0655731d5
commit fbe93fc53b
4 changed files with 88 additions and 86 deletions

View File

@@ -66,9 +66,7 @@ def load_img(display=True, key=None, device="cuda"):
         st.image(image)
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(
-        lambda x: x - x % 64, (w, h)
-    )  # resize to integer multiple of 64
+    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
     image = image.resize((width, height))
     image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
@@ -78,26 +76,19 @@ def load_img(display=True, key=None, device="cuda"):
 def run_txt2img(
     state,
-    version: str,
+    model_id: ModelArchitecture,
     prompt: str,
     negative_prompt: str,
     return_latents=False,
     stage2strength=None,
 ):
-    spec: SamplingSpec = state["spec"]
     model: SamplingPipeline = state["model"]
     params: SamplingParams = state["params"]
-    if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
-        width, height = st.selectbox(
-            "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
-        )
+    if model_id in sdxl_base_model_list:
+        width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
     else:
-        height = int(
-            st.number_input("H", value=spec.height, min_value=64, max_value=2048)
-        )
-        width = int(
-            st.number_input("W", value=spec.width, min_value=64, max_value=2048)
-        )
+        height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048))
+        width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048))
     params = init_embedder_options(
         get_unique_embedder_keys_from_conditioner(model.model.conditioner),
@@ -207,12 +198,12 @@ def apply_refiner(
 sdxl_base_model_list = [
-    ModelArchitecture.SDXL_V1_BASE,
+    ModelArchitecture.SDXL_V1_0_BASE,
     ModelArchitecture.SDXL_V0_9_BASE,
 ]
 sdxl_refiner_model_list = [
-    ModelArchitecture.SDXL_V1_REFINER,
+    ModelArchitecture.SDXL_V1_0_REFINER,
     ModelArchitecture.SDXL_V0_9_REFINER,
 ]
@@ -239,9 +230,7 @@ if __name__ == "__main__":
     else:
         add_pipeline = False
-    seed = int(
-        st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
-    )
+    seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)))
     seed_everything(seed)
     save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
@@ -281,9 +270,7 @@ if __name__ == "__main__":
         st.write("**Refiner Options:**")
         specs2 = model_specs[version2]
-        state2 = init_st(
-            specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
-        )
+        state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap)
         params2 = state2["params"]
         params2.img2img_strength = st.number_input(
@@ -309,7 +296,7 @@ if __name__ == "__main__":
     if mode == "txt2img":
         out = run_txt2img(
             state=state,
-            version=str(version),
+            model_id=version_enum,
             prompt=prompt,
             negative_prompt=negative_prompt,
             return_latents=add_pipeline,
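Note: the demo now passes the ModelArchitecture enum member (model_id=version_enum) instead of the raw version string, and chooses between the fixed SDXL ratio presets and free-form H/W inputs by list membership. A minimal sketch of that check, assuming the API module is importable as sgm.inference.api (file paths are not shown in this view); the helper name uses_fixed_ratios is illustrative, not part of the diff:

from sgm.inference.api import ModelArchitecture  # assumed import path

sdxl_base_model_list = [
    ModelArchitecture.SDXL_V1_0_BASE,
    ModelArchitecture.SDXL_V0_9_BASE,
]

def uses_fixed_ratios(model_id: ModelArchitecture) -> bool:
    # SDXL base checkpoints get the SD_XL_BASE_RATIOS presets; everything else
    # gets free-form H/W inputs seeded from the per-model params in state["params"].
    return model_id in sdxl_base_model_list

# ModelArchitecture is a str-valued Enum, so the sidebar string maps straight to a member.
assert uses_fixed_ratios(ModelArchitecture("stable-diffusion-xl-v1-base"))
assert not uses_fixed_ratios(ModelArchitecture.SD_2_1)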

View File

@@ -48,7 +48,7 @@ def init_st(
     state["model"] = pipeline
     state["ckpt"] = ckpt if load_ckpt else None
     state["config"] = config
-    state["params"] = SamplingParams()
+    state["params"] = spec.default_params
     if load_filter:
         state["filter"] = DeepFloydDataFiltering(verbose=False)
     else:
@@ -132,9 +132,7 @@ def show_samples(samples, outputs):
 def get_guider(params: SamplingParams, key=1) -> SamplingParams:
     params.guider = Guider(
-        st.sidebar.selectbox(
-            f"Discretization #{key}", [member.value for member in Guider]
-        )
+        st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
     )
     if params.guider == Guider.VANILLA:
@@ -165,14 +163,10 @@ def init_sampling(
     num_rows, num_cols = 1, 1
     if specify_num_samples:
-        num_cols = st.number_input(
-            f"num cols #{key}", value=2, min_value=1, max_value=10
-        )
+        num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
     params.steps = int(
-        st.sidebar.number_input(
-            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
-        )
+        st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
     )
     params.sampler = Sampler(
@@ -220,15 +214,11 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
         )
     elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
-        params.s_noise = st.sidebar.number_input(
-            "s_noise", value=params.s_noise, min_value=0.0
-        )
+        params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
         params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
     elif params.sampler == Sampler.LINEAR_MULTISTEP:
-        params.order = int(
-            st.sidebar.number_input("order", value=params.order, min_value=1)
-        )
+        params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
     return params
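Note: with state["params"] = spec.default_params, the widgets above (H/W, steps, sampler settings) are now seeded from the selected model's SamplingSpec rather than the old global SamplingParams() defaults. A minimal sketch of what that lookup yields, assuming the API module is importable as sgm.inference.api and that SamplingParams is a standard dataclass (neither detail is shown in this view):

from dataclasses import replace

from sgm.inference.api import ModelArchitecture, model_specs  # assumed import path

spec = model_specs[ModelArchitecture.SD_2_1]
defaults = spec.default_params
print(defaults.width, defaults.height, defaults.steps, defaults.scale)  # 512 512 40 7.0

# default_params is a shared instance, so UI code that mutates params in place may
# want a copy to keep the spec's defaults intact.
params = replace(defaults, steps=20)

The printed values match the SD_2_1 entry added to model_specs in the API diff further down.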

View File

@@ -24,8 +24,8 @@ from typing import Optional, Dict, Any, Union
 class ModelArchitecture(str, Enum):
-    SDXL_V1_BASE = "stable-diffusion-xl-v1-base"
-    SDXL_V1_REFINER = "stable-diffusion-xl-v1-refiner"
+    SDXL_V1_0_BASE = "stable-diffusion-xl-v1-base"
+    SDXL_V1_0_REFINER = "stable-diffusion-xl-v1-refiner"
     SDXL_V0_9_BASE = "stable-diffusion-xl-v0-9-base"
     SDXL_V0_9_REFINER = "stable-diffusion-xl-v0-9-refiner"
     SD_2_1 = "stable-diffusion-v2-1"
@@ -59,24 +59,21 @@ class Thresholder(str, Enum):
 class SamplingParams:
     """
     Parameters for sampling.
-    The defaults here are derived from user preference testing.
-    They will be subject to change in the future, likely pulled
-    from model specs instead of global defaults.
     """
-    width: int = 1024
-    height: int = 1024
-    steps: int = 40
+    width: int
+    height: int
+    steps: int
     sampler: Sampler = Sampler.EULER_EDM
     discretization: Discretization = Discretization.LEGACY_DDPM
     guider: Guider = Guider.VANILLA
     thresholder: Thresholder = Thresholder.NONE
-    scale: float = 5.0
+    scale: float
     aesthetic_score: float = 6.0
     negative_aesthetic_score: float = 2.5
     img2img_strength: float = 1.0
-    orig_width: int = width
-    orig_height: int = height
+    orig_width: int = 1024
+    orig_height: int = 1024
     crop_coords_top: int = 0
     crop_coords_left: int = 0
     sigma_min: float = 0.0292
@@ -100,8 +97,10 @@ class SamplingSpec:
     config: str
     ckpt: str
     is_guided: bool
+    default_params: SamplingParams
+# The defaults here are derived from user preference testing.
 model_specs = {
     ModelArchitecture.SD_2_1: SamplingSpec(
         height=512,
@@ -112,6 +111,12 @@ model_specs = {
         config="sd_2_1.yaml",
         ckpt="v2-1_512-ema-pruned.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=512,
+            height=512,
+            steps=40,
+            scale=7.0,
+        ),
     ),
     ModelArchitecture.SD_2_1_768: SamplingSpec(
         height=768,
@@ -122,6 +127,12 @@ model_specs = {
         config="sd_2_1_768.yaml",
         ckpt="v2-1_768-ema-pruned.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=768,
+            height=768,
+            steps=40,
+            scale=7.0,
+        ),
     ),
     ModelArchitecture.SDXL_V0_9_BASE: SamplingSpec(
         height=1024,
@@ -132,6 +143,7 @@ model_specs = {
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_0.9.safetensors",
         is_guided=True,
+        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
     ModelArchitecture.SDXL_V0_9_REFINER: SamplingSpec(
         height=1024,
@@ -142,8 +154,11 @@ model_specs = {
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_0.9.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
+        ),
     ),
-    ModelArchitecture.SDXL_V1_BASE: SamplingSpec(
+    ModelArchitecture.SDXL_V1_0_BASE: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,
@@ -152,8 +167,9 @@ model_specs = {
         config="sd_xl_base.yaml",
         ckpt="sd_xl_base_1.0.safetensors",
         is_guided=True,
+        default_params=SamplingParams(width=1024, height=1024, steps=40, scale=5.0),
     ),
-    ModelArchitecture.SDXL_V1_REFINER: SamplingSpec(
+    ModelArchitecture.SDXL_V1_0_REFINER: SamplingSpec(
         height=1024,
         width=1024,
         channels=4,
@@ -162,10 +178,39 @@ model_specs = {
         config="sd_xl_refiner.yaml",
         ckpt="sd_xl_refiner_1.0.safetensors",
         is_guided=True,
+        default_params=SamplingParams(
+            width=1024, height=1024, steps=40, scale=5.0, img2img_strength=0.15
+        ),
     ),
 }
+def wrap_discretization(
+    discretization, image_strength=None, noise_strength=None, steps=None
+):
+    if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
+        discretization, Txt2NoisyDiscretizationWrapper
+    ):
+        return discretization  # Already wrapped
+    if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
+        discretization = Img2ImgDiscretizationWrapper(
+            discretization, strength=image_strength
+        )
+    if (
+        noise_strength is not None
+        and noise_strength < 1.0
+        and noise_strength > 0.0
+        and steps is not None
+    ):
+        discretization = Txt2NoisyDiscretizationWrapper(
+            discretization,
+            strength=noise_strength,
+            original_steps=steps,
+        )
+    return discretization
 class SamplingPipeline:
     def __init__(
         self,
@@ -231,17 +276,19 @@ class SamplingPipeline:
     def text_to_image(
         self,
-        params: SamplingParams,
         prompt: str,
+        params: Optional[SamplingParams] = None,
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
         noise_strength: Optional[float] = None,
         filter=None,
     ):
+        if params is None:
+            params = self.specs.default_params
         sampler = get_sampler_config(params)
-        sampler.discretization = self.wrap_discretization(
+        sampler.discretization = wrap_discretization(
             sampler.discretization,
             image_strength=None,
             noise_strength=noise_strength,
@@ -270,18 +317,20 @@ class SamplingPipeline:
     def image_to_image(
         self,
-        params: SamplingParams,
         image: torch.Tensor,
         prompt: str,
+        params: Optional[SamplingParams] = None,
         negative_prompt: str = "",
         samples: int = 1,
         return_latents: bool = False,
         noise_strength: Optional[float] = None,
         filter=None,
     ):
+        if params is None:
+            params = self.specs.default_params
         sampler = get_sampler_config(params)
-        sampler.discretization = self.wrap_discretization(
+        sampler.discretization = wrap_discretization(
             sampler.discretization,
             image_strength=params.img2img_strength,
             noise_strength=noise_strength,
@@ -308,44 +357,20 @@ class SamplingPipeline:
             device=self.device_manager,
         )
-    def wrap_discretization(
-        self, discretization, image_strength=None, noise_strength=None, steps=None
-    ):
-        if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
-            discretization, Txt2NoisyDiscretizationWrapper
-        ):
-            return discretization  # Already wrapped
-        if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
-            discretization = Img2ImgDiscretizationWrapper(
-                discretization, strength=image_strength
-            )
-        if (
-            noise_strength is not None
-            and noise_strength < 1.0
-            and noise_strength > 0.0
-            and steps is not None
-        ):
-            discretization = Txt2NoisyDiscretizationWrapper(
-                discretization,
-                strength=noise_strength,
-                original_steps=steps,
-            )
-        return discretization
     def refiner(
         self,
         image: torch.Tensor,
         prompt: str,
         negative_prompt: str = "",
-        params: SamplingParams = SamplingParams(
-            sampler=Sampler.EULER_EDM, steps=40, img2img_strength=0.15
-        ),
+        params: Optional[SamplingParams] = None,
         samples: int = 1,
         return_latents: bool = False,
         filter: Any = None,
         add_noise: bool = False,
     ):
+        if params is None:
+            params = self.specs.default_params
         sampler = get_sampler_config(params)
         value_dict = {
             "orig_width": image.shape[3] * 8,

View File

@@ -27,7 +27,7 @@ class TestInference:
     @fixture(
         scope="class",
         params=[
-            [ModelArchitecture.SDXL_V1_BASE, ModelArchitecture.SDXL_V1_REFINER],
+            [ModelArchitecture.SDXL_V1_0_BASE, ModelArchitecture.SDXL_V1_0_REFINER],
             [ModelArchitecture.SDXL_V0_9_BASE, ModelArchitecture.SDXL_V0_9_REFINER],
         ],
         ids=["SDXL_V1", "SDXL_V0_9"],