run black

This commit is contained in:
Stephan Auerhahn
2023-08-12 05:40:25 -07:00
parent 5fde7e73b8
commit 65c6ec1cec
3 changed files with 51 additions and 17 deletions

View File

@@ -66,7 +66,9 @@ def load_img(display=True, key=None, device="cuda"):
st.image(image)
w, h = image.size
print(f"loaded input image of size ({w}, {h})")
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
width, height = map(
lambda x: x - x % 64, (w, h)
) # resize to integer multiple of 64
image = image.resize((width, height))
image = np.array(image.convert("RGB"))
image = image[None].transpose(0, 3, 1, 2)
@@ -85,10 +87,16 @@ def run_txt2img(
model: SamplingPipeline = state["model"]
params: SamplingParams = state["params"]
if model_id in sdxl_base_model_list:
width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
width, height = st.selectbox(
"Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
)
else:
height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048))
width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048))
height = int(
st.number_input("H", value=params.height, min_value=64, max_value=2048)
)
width = int(
st.number_input("W", value=params.width, min_value=64, max_value=2048)
)
params = init_embedder_options(
get_unique_embedder_keys_from_conditioner(model.model.conditioner),
@@ -230,7 +238,9 @@ if __name__ == "__main__":
else:
add_pipeline = False
seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)))
seed = int(
st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
)
seed_everything(seed)
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
@@ -270,7 +280,9 @@ if __name__ == "__main__":
st.write("**Refiner Options:**")
specs2 = model_specs[version2]
state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap)
state2 = init_st(
specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
)
params2 = state2["params"]
params2.img2img_strength = st.number_input(