mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-23 06:34:27 +01:00
formatting, remove reference
This commit is contained in:
@@ -105,7 +105,9 @@ def load_img(display=True, key=None, device="cuda"):
|
||||
st.image(image)
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
width, height = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
image = image.resize((width, height))
|
||||
image = np.array(image.convert("RGB"))
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
@@ -193,7 +195,9 @@ def run_img2img(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
strength = st.number_input("**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0)
|
||||
strength = st.number_input(
|
||||
"**Img2Img Strength**", value=0.75, min_value=0.0, max_value=1.0
|
||||
)
|
||||
sampler, num_rows, num_cols = init_sampling(
|
||||
img2img_strength=strength,
|
||||
stage2strength=stage2strength,
|
||||
@@ -280,8 +284,6 @@ if __name__ == "__main__":
|
||||
save_locally, save_path = init_save_locally(os.path.join(SAVE_PATH, version))
|
||||
|
||||
state = init_st(version_dict, load_filter=True)
|
||||
if state["msg"]:
|
||||
st.info(state["msg"])
|
||||
model = state["model"]
|
||||
|
||||
is_legacy = version_dict["is_legacy"]
|
||||
@@ -308,7 +310,6 @@ if __name__ == "__main__":
|
||||
|
||||
version_dict2 = VERSION2SPECS[version2]
|
||||
state2 = init_st(version_dict2, load_filter=False)
|
||||
st.info(state2["msg"])
|
||||
|
||||
stage2strength = st.number_input(
|
||||
"**Refinement strength**", value=0.15, min_value=0.0, max_value=1.0
|
||||
|
||||
@@ -34,7 +34,9 @@ def init_st(version_dict, load_ckpt=True, load_filter=True):
|
||||
ckpt = version_dict["ckpt"]
|
||||
|
||||
config = OmegaConf.load(config)
|
||||
model = load_model_from_config(config, ckpt if load_ckpt else None, freeze=False)
|
||||
model = load_model_from_config(
|
||||
config, ckpt if load_ckpt else None, freeze=False
|
||||
)
|
||||
|
||||
state["model"] = model
|
||||
state["ckpt"] = ckpt if load_ckpt else None
|
||||
@@ -154,9 +156,13 @@ def get_guider(key):
|
||||
)
|
||||
|
||||
if guider == "IdentityGuider":
|
||||
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
}
|
||||
elif guider == "VanillaCFG":
|
||||
scale = st.number_input(f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0)
|
||||
scale = st.number_input(
|
||||
f"cfg-scale #{key}", value=5.0, min_value=0.0, max_value=100.0
|
||||
)
|
||||
|
||||
thresholder = st.sidebar.selectbox(
|
||||
f"Thresholder #{key}",
|
||||
@@ -189,9 +195,13 @@ def init_sampling(
|
||||
):
|
||||
num_rows, num_cols = 1, 1
|
||||
if specify_num_samples:
|
||||
num_cols = st.number_input(f"num cols #{key}", value=2, min_value=1, max_value=10)
|
||||
num_cols = st.number_input(
|
||||
f"num cols #{key}", value=2, min_value=1, max_value=10
|
||||
)
|
||||
|
||||
steps = st.sidebar.number_input(f"steps #{key}", value=40, min_value=1, max_value=1000)
|
||||
steps = st.sidebar.number_input(
|
||||
f"steps #{key}", value=40, min_value=1, max_value=1000
|
||||
)
|
||||
sampler = st.sidebar.selectbox(
|
||||
f"Sampler #{key}",
|
||||
[
|
||||
@@ -218,7 +228,9 @@ def init_sampling(
|
||||
|
||||
sampler = get_sampler(sampler, steps, discretization_config, guider_config, key=key)
|
||||
if img2img_strength < 1.0:
|
||||
st.warning(f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper")
|
||||
st.warning(
|
||||
f"Wrapping {sampler.__class__.__name__} with Img2ImgDiscretizationWrapper"
|
||||
)
|
||||
sampler.discretization = Img2ImgDiscretizationWrapper(
|
||||
sampler.discretization, strength=img2img_strength
|
||||
)
|
||||
@@ -279,7 +291,10 @@ def get_sampler(sampler_name, steps, discretization_config, guider_config, key=1
|
||||
s_noise=s_noise,
|
||||
verbose=True,
|
||||
)
|
||||
elif sampler_name == "EulerAncestralSampler" or sampler_name == "DPMPP2SAncestralSampler":
|
||||
elif (
|
||||
sampler_name == "EulerAncestralSampler"
|
||||
or sampler_name == "DPMPP2SAncestralSampler"
|
||||
):
|
||||
s_noise = st.sidebar.number_input("s_noise", value=1.0, min_value=0.0)
|
||||
eta = st.sidebar.number_input("eta", value=1.0, min_value=0.0)
|
||||
|
||||
|
||||
@@ -181,20 +181,30 @@ class SamplingPipeline:
|
||||
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
|
||||
if not os.path.exists(model_path):
|
||||
# This supports development installs where checkpoints is root level of the repo
|
||||
model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
|
||||
model_path = (
|
||||
pathlib.Path(__file__).parent.parent.parent.resolve()
|
||||
/ "checkpoints"
|
||||
)
|
||||
if config_path is None:
|
||||
config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
|
||||
config_path = (
|
||||
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
|
||||
)
|
||||
if not os.path.exists(config_path):
|
||||
# This supports development installs where configs is root level of the repo
|
||||
config_path = (
|
||||
pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
|
||||
pathlib.Path(__file__).parent.parent.parent.resolve()
|
||||
/ "configs/inference"
|
||||
)
|
||||
self.config = str(config_path / self.specs.config)
|
||||
self.ckpt = str(model_path / self.specs.ckpt)
|
||||
if not os.path.exists(self.config):
|
||||
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
|
||||
raise ValueError(
|
||||
f"Config {self.config} not found, check model spec or config_path"
|
||||
)
|
||||
if not os.path.exists(self.ckpt):
|
||||
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
|
||||
raise ValueError(
|
||||
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
|
||||
)
|
||||
self.device = device
|
||||
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
||||
|
||||
@@ -221,7 +231,9 @@ class SamplingPipeline:
|
||||
sampler = get_sampler_config(params)
|
||||
if stage2strength is not None:
|
||||
sampler.discretization = Txt2NoisyDiscretizationWrapper(
|
||||
sampler.discretization, strength=stage2strength, original_steps=params.steps
|
||||
sampler.discretization,
|
||||
strength=stage2strength,
|
||||
original_steps=params.steps,
|
||||
)
|
||||
value_dict = asdict(params)
|
||||
value_dict["prompt"] = prompt
|
||||
@@ -275,7 +287,10 @@ class SamplingPipeline:
|
||||
)
|
||||
|
||||
def wrap_discretization(self, discretization, strength=1.0):
|
||||
if not isinstance(discretization, Img2ImgDiscretizationWrapper) and strength < 1.0:
|
||||
if (
|
||||
not isinstance(discretization, Img2ImgDiscretizationWrapper)
|
||||
and strength < 1.0
|
||||
):
|
||||
return Img2ImgDiscretizationWrapper(discretization, strength=strength)
|
||||
return discretization
|
||||
|
||||
@@ -322,7 +337,9 @@ class SamplingPipeline:
|
||||
|
||||
def get_guider_config(params: SamplingParams):
|
||||
if params.guider == Guider.IDENTITY:
|
||||
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
}
|
||||
elif params.guider == Guider.VANILLA:
|
||||
scale = params.scale
|
||||
|
||||
|
||||
@@ -35,9 +35,9 @@ class WatermarkEmbedder:
|
||||
if squeeze:
|
||||
image = image[None, ...]
|
||||
n = image.shape[0]
|
||||
image_np = rearrange((255 * image).detach().cpu(), "n b c h w -> (n b) h w c").numpy()[
|
||||
:, :, :, ::-1
|
||||
]
|
||||
image_np = rearrange(
|
||||
(255 * image).detach().cpu(), "n b c h w -> (n b) h w c"
|
||||
).numpy()[:, :, :, ::-1]
|
||||
# torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
|
||||
for k in range(image_np.shape[0]):
|
||||
image_np[k] = self.encoder.encode(image_np[k], "dwtDct")
|
||||
@@ -184,7 +184,9 @@ def do_sample(
|
||||
randn = torch.randn(shape).to(device)
|
||||
|
||||
def denoiser(input, sigma, c):
|
||||
return model.denoiser(model.model, input, sigma, c, **additional_model_inputs)
|
||||
return model.denoiser(
|
||||
model.model, input, sigma, c, **additional_model_inputs
|
||||
)
|
||||
|
||||
with ModelOnDevice(model.denoiser, device):
|
||||
with ModelOnDevice(model.model, device):
|
||||
@@ -211,10 +213,14 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
for key in keys:
|
||||
if key == "txt":
|
||||
batch["txt"] = (
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
np.repeat([value_dict["prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
batch_uc["txt"] = (
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N)).reshape(N).tolist()
|
||||
np.repeat([value_dict["negative_prompt"]], repeats=math.prod(N))
|
||||
.reshape(N)
|
||||
.tolist()
|
||||
)
|
||||
elif key == "original_size_as_tuple":
|
||||
batch["original_size_as_tuple"] = (
|
||||
@@ -224,7 +230,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
)
|
||||
elif key == "crop_coords_top_left":
|
||||
batch["crop_coords_top_left"] = (
|
||||
torch.tensor([value_dict["crop_coords_top"], value_dict["crop_coords_left"]])
|
||||
torch.tensor(
|
||||
[value_dict["crop_coords_top"], value_dict["crop_coords_left"]]
|
||||
)
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
@@ -233,7 +241,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
torch.tensor([value_dict["aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
)
|
||||
batch_uc["aesthetic_score"] = (
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]]).to(device).repeat(*N, 1)
|
||||
torch.tensor([value_dict["negative_aesthetic_score"]])
|
||||
.to(device)
|
||||
.repeat(*N, 1)
|
||||
)
|
||||
|
||||
elif key == "target_size_as_tuple":
|
||||
@@ -254,7 +264,9 @@ def get_batch(keys, value_dict, N: Union[List, ListConfig], device="cuda"):
|
||||
def get_input_image_tensor(image: Image.Image, device="cuda"):
|
||||
w, h = image.size
|
||||
print(f"loaded input image of size ({w}, {h})")
|
||||
width, height = map(lambda x: x - x % 64, (w, h)) # resize to integer multiple of 64
|
||||
width, height = map(
|
||||
lambda x: x - x % 64, (w, h)
|
||||
) # resize to integer multiple of 64
|
||||
image = image.resize((width, height))
|
||||
image_array = np.array(image.convert("RGB"))
|
||||
image_array = image_array[None].transpose(0, 3, 1, 2)
|
||||
|
||||
Reference in New Issue
Block a user