mirror of https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 12:24:27 +01:00
run black
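Every hunk in this commit is mechanical output of the black formatter: a line that exceeded black's default 88-character limit is rewrapped into a multi-line call, with no change in behaviour. Assuming black is installed, running `black --check --diff .` from the repository root should report the same rewrites without applying them.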
@@ -66,7 +66,9 @@ def load_img(display=True, key=None, device="cuda"):
         st.image(image)
     w, h = image.size
     print(f"loaded input image of size ({w}, {h})")
-    width, height = map(lambda x: x - x % 64, (w, h))  # resize to integer multiple of 64
+    width, height = map(
+        lambda x: x - x % 64, (w, h)
+    )  # resize to integer multiple of 64
     image = image.resize((width, height))
     image = np.array(image.convert("RGB"))
     image = image[None].transpose(0, 3, 1, 2)
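Aside from the wrapping, the rounding above is unchanged: `x - x % 64` floors each dimension to the nearest multiple of 64 before resizing. A standalone sketch of that arithmetic (the helper name is illustrative, not from the repository):

def floor_to_multiple(x: int, base: int = 64) -> int:
    # floor x down to the nearest multiple of base, e.g. 1023 -> 960
    return x - x % base

assert floor_to_multiple(1023) == 960
assert floor_to_multiple(1024) == 1024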
@@ -85,10 +87,16 @@ def run_txt2img(
     model: SamplingPipeline = state["model"]
     params: SamplingParams = state["params"]
     if model_id in sdxl_base_model_list:
-        width, height = st.selectbox("Resolution:", list(SD_XL_BASE_RATIOS.values()), 10)
+        width, height = st.selectbox(
+            "Resolution:", list(SD_XL_BASE_RATIOS.values()), 10
+        )
     else:
-        height = int(st.number_input("H", value=params.height, min_value=64, max_value=2048))
-        width = int(st.number_input("W", value=params.width, min_value=64, max_value=2048))
+        height = int(
+            st.number_input("H", value=params.height, min_value=64, max_value=2048)
+        )
+        width = int(
+            st.number_input("W", value=params.width, min_value=64, max_value=2048)
+        )

     params = init_embedder_options(
         get_unique_embedder_keys_from_conditioner(model.model.conditioner),
@@ -230,7 +238,9 @@ if __name__ == "__main__":
     else:
         add_pipeline = False

-    seed = int(st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9)))
+    seed = int(
+        st.sidebar.number_input("seed", value=42, min_value=0, max_value=int(1e9))
+    )
    seed_everything(seed)

    save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, str(version)))
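The seed widget still feeds `seed_everything`, which seeds Python's `random`, NumPy and torch in one call so sampling is reproducible. A minimal sketch, assuming `seed_everything` comes from pytorch_lightning (the import is not shown in this diff):

from pytorch_lightning import seed_everything

seed_everything(42)  # same value the sidebar defaults to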
@@ -270,7 +280,9 @@ if __name__ == "__main__":
         st.write("**Refiner Options:**")

         specs2 = model_specs[version2]
-        state2 = init_st(specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap)
+        state2 = init_st(
+            specs2, load_filter=False, use_fp16=use_fp16, enable_swap=enable_swap
+        )
         params2 = state2["params"]

         params2.img2img_strength = st.number_input(

@@ -132,7 +132,9 @@ def show_samples(samples, outputs):

 def get_guider(params: SamplingParams, key=1) -> SamplingParams:
     params.guider = Guider(
-        st.sidebar.selectbox(f"Discretization #{key}", [member.value for member in Guider])
+        st.sidebar.selectbox(
+            f"Discretization #{key}", [member.value for member in Guider]
+        )
     )

     if params.guider == Guider.VANILLA:
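`Guider(...)` works here because `st.sidebar.selectbox` returns one of the `.value` strings offered to it, and calling an Enum class with a value looks the member back up. The same round trip in isolation, with a stand-in enum rather than the repository's `Guider`:

from enum import Enum

class GuiderChoice(Enum):  # stand-in; not the repository's Guider enum
    VANILLA = "VanillaCFG"
    IDENTITY = "IdentityGuider"

values = [member.value for member in GuiderChoice]  # what the selectbox offers
chosen = GuiderChoice(values[0])  # value -> member lookup
assert chosen is GuiderChoice.VANILLA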
@@ -163,10 +165,14 @@ def init_sampling(

     num_rows, num_cols = 1, 1
     if specify_num_samples:
-        num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
+        num_cols = st.number_input(
+            f"num cols #{key}", value=2, min_value=1, max_value=10
+        )

     params.steps = int(
-        st.sidebar.number_input(f"steps #{key}", value=params.steps, min_value=1, max_value=1000)
+        st.sidebar.number_input(
+            f"steps #{key}", value=params.steps, min_value=1, max_value=1000
+        )
     )

     params.sampler = Sampler(
@@ -214,11 +220,15 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
         )

     elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
-        params.s_noise = st.sidebar.number_input("s_noise", value=params.s_noise, min_value=0.0)
+        params.s_noise = st.sidebar.number_input(
+            "s_noise", value=params.s_noise, min_value=0.0
+        )
         params.eta = st.sidebar.number_input("eta", value=params.eta, min_value=0.0)

     elif params.sampler == Sampler.LINEAR_MULTISTEP:
-        params.order = int(st.sidebar.number_input("order", value=params.order, min_value=1))
+        params.order = int(
+            st.sidebar.number_input("order", value=params.order, min_value=1)
+        )
     return params

@@ -185,13 +185,17 @@ model_specs = {
 }


-def wrap_discretization(discretization, image_strength=None, noise_strength=None, steps=None):
+def wrap_discretization(
+    discretization, image_strength=None, noise_strength=None, steps=None
+):
     if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
         discretization, Txt2NoisyDiscretizationWrapper
     ):
         return discretization  # Already wrapped
     if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
-        discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength)
+        discretization = Img2ImgDiscretizationWrapper(
+            discretization, strength=image_strength
+        )

     if (
         noise_strength is not None
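`wrap_discretization` follows an idempotent-wrapping pattern: return early if the discretization is already wrapped, and only wrap when the strength is strictly between 0 and 1. A generic sketch of that guard, with a hypothetical class standing in for the repository's wrapper classes:

class Wrapper:  # hypothetical stand-in for Img2ImgDiscretizationWrapper
    def __init__(self, inner, strength):
        self.inner, self.strength = inner, strength

def wrap_once(obj, strength=None):
    if isinstance(obj, Wrapper):
        return obj  # already wrapped, leave untouched
    if strength is not None and 0.0 < strength < 1.0:
        return Wrapper(obj, strength)  # wrap only for a partial strength
    return obj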
@@ -245,13 +249,19 @@ class SamplingPipeline:
         self.config = os.path.join(config_path, "inference", self.specs.config)
         self.ckpt = os.path.join(model_path, self.specs.ckpt)
         if not os.path.exists(self.config):
-            raise ValueError(f"Config {self.config} not found, check model spec or config_path")
+            raise ValueError(
+                f"Config {self.config} not found, check model spec or config_path"
+            )
         if not os.path.exists(self.ckpt):
-            raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
+            raise ValueError(
+                f"Checkpoint {self.ckpt} not found, check model spec or config_path"
+            )

         self.device_manager = get_model_manager(device)

-        self.model = self._load_model(device_manager=self.device_manager, use_fp16=use_fp16)
+        self.model = self._load_model(
+            device_manager=self.device_manager, use_fp16=use_fp16
+        )

     def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
         config = OmegaConf.load(self.config)
@@ -396,7 +406,9 @@ class SamplingPipeline:
 def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
     guider_config: Dict[str, Any]
     if params.guider == Guider.IDENTITY:
-        guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
+        guider_config = {
+            "target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
+        }
     elif params.guider == Guider.VANILLA:
         scale = params.scale
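The `guider_config` dictionary built here uses a dotted `target` path to name the class to instantiate. A generic sketch of how such a dict is typically turned into an object (not the repository's own helper):

import importlib

def build_from_target(config: dict):
    # split "pkg.module.ClassName" into a module path and a class name
    module_path, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**config.get("params", {}))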