mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 12:24:27 +01:00
fix path resolution bug
This commit is contained in:
@@ -181,30 +181,20 @@ class SamplingPipeline:
|
||||
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
|
||||
if not os.path.exists(model_path):
|
||||
# This supports development installs where checkpoints is root level of the repo
|
||||
model_path = (
|
||||
pathlib.Path(__file__).parent.parent.parent.resolve()
|
||||
/ "checkpoints"
|
||||
)
|
||||
model_path = pathlib.Path(__file__).parent.parent.parent.resolve() / "checkpoints"
|
||||
if config_path is None:
|
||||
config_path = (
|
||||
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
|
||||
)
|
||||
config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
|
||||
if not os.path.exists(config_path):
|
||||
# This supports development installs where configs is root level of the repo
|
||||
config_path = (
|
||||
pathlib.Path(__file__).parent.parent.parent.resolve()
|
||||
/ "configs/inference"
|
||||
pathlib.Path(__file__).parent.parent.parent.resolve() / "configs/inference"
|
||||
)
|
||||
self.config = str(config_path / self.specs.config)
|
||||
self.ckpt = str(model_path / self.specs.ckpt)
|
||||
self.config = str(pathlib.Path(config_path) / self.specs.config)
|
||||
self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt)
|
||||
if not os.path.exists(self.config):
|
||||
raise ValueError(
|
||||
f"Config {self.config} not found, check model spec or config_path"
|
||||
)
|
||||
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
|
||||
if not os.path.exists(self.ckpt):
|
||||
raise ValueError(
|
||||
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
|
||||
)
|
||||
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
|
||||
self.device = device
|
||||
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
||||
|
||||
@@ -300,9 +290,7 @@ class SamplingPipeline:
|
||||
):
|
||||
return discretization # Already wrapped
|
||||
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
|
||||
discretization = Img2ImgDiscretizationWrapper(
|
||||
discretization, strength=image_strength
|
||||
)
|
||||
discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength)
|
||||
|
||||
if (
|
||||
noise_strength is not None
|
||||
@@ -361,9 +349,7 @@ class SamplingPipeline:
|
||||
|
||||
def get_guider_config(params: SamplingParams):
|
||||
if params.guider == Guider.IDENTITY:
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
}
|
||||
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
|
||||
elif params.guider == Guider.VANILLA:
|
||||
scale = params.scale
|
||||
|
||||
|
||||
Reference in New Issue
Block a user