PR fixes, model specific defaults

This commit is contained in:
Stephan Auerhahn
2023-08-12 05:33:16 -07:00
parent c0655731d5
commit fbe93fc53b
4 changed files with 88 additions and 86 deletions

View File

@@ -66,9 +66,7 @@ def load_img(display=True, key=None, device="cuda"):
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
@@ -78,26 +76,19 @@ def load_img(display=True, key=None, device="cuda"):
def run_txt2img(
state,
version: str,
model_id: ModelArchitecture,
prompt: str,
negative_prompt: str,
return_latents=False,
stage2strength=None,
):
spec: SamplingSpec = state["spec"]
model: SamplingPipeline = state["model"]
params: SamplingParams = state["params"]
if version.startswith("stable-diffusion-xl") and version.endswith("-base"):
width, height = st.selectbox(
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
)
if model_id in sdxl_base_model_list:
width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
else:
height = int(
st.number_input("H", value=spec.height, min_value=64, max_value=2048)
)
width = int(
st.number_input("W", value=spec.width, min_value=64, max_value=2048)
)
height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048))
width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048))
params = init_embedder_options(
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
@@ -207,12 +198,12 @@ def apply_refiner(
sdxl_base_model_list = [
ModelArchitecture.SDXL_V1_BASE,
ModelArchitecture.SDXL_V1_0_BASE,
ModelArchitecture.SDXL_V0_9_BASE,
]
sdxl_refiner_model_list = [
ModelArchitecture.SDXL_V1_REFINER,
ModelArchitecture.SDXL_V1_0_REFINER,
ModelArchitecture.SDXL_V0_9_REFINER,
]
@@ -239,9 +230,7 @@ if __name__ == "__main__":
else:
add_pipeline = False
seed = int(
st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
)
seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)))
seed_everything(seed)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
@@ -281,9 +270,7 @@ if __name__ == "__main__":
st.write("**Refiner Options:**")
specs2 = model_specs[version2]
state2 = init_st(
specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
)
state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap)
params2 = state2["params"]
params2.img2img_strength = st.number_input(
@@ -309,7 +296,7 @@ if __name__ == "__main__":
if mode == "txt2img":
out = run_txt2img(
state=state,
version=str(version),
model_id=version_enum,
prompt=prompt,
negative_prompt=negative_prompt,
return_latents=add_pipeline,

View File

@@ -48,7 +48,7 @@ def init_st(
state["model"] = pipeline
state["ckpt"] = ckpt if load_ckpt else None
state["config"] = config
state["params"] = SamplingParams()
state["params"] = spec.default_params
if load_filter:
state["filter"] = DeepFloydDataFiltering(verbose=False)
else:
@@ -132,9 +132,7 @@ def show_samples(samples, outputs):
def get_guider(params: SamplingParams, key=1) -> SamplingParams:
params.guider = Guider(
st.sidebar.selectbox(
f"Discretization #{key}", [member.value for member in Guider]
)
st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
)
if params.guider == Guider.VANILLA:
@@ -165,14 +163,10 @@ def init_sampling(
num_rows, num_cols = 1, 1
if specify_num_samples:
num_cols = st.number_input(
f"num cols #{key}", value=2, min_value=1, max_value=10
)
num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
params.steps = int(
st.sidebar.number_input(
f"steps #{key}", value=params.steps, min_value=1, max_value=1000
)
st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
)
params.sampler = Sampler(
@@ -220,15 +214,11 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
)
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
params.s_noise = st.sidebar.number_input(
"s_noise", value=params.s_noise, min_value=0.0
)
params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)
elif params.sampler == Sampler.LINEAR_MULTISTEP:
params.order = int(
st.sidebar.number_input("order", value=params.order, min_value=1)
)
params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
return params