This commit is contained in:
Stephan Auerhahn
2023-08-04 00:00:51 +00:00
parent 44943df4f2
commit baf79d2d79

View File

@@ -165,7 +165,7 @@ class SamplingPipeline:
config_path=None,
device="cuda",
use_fp16=True,
) -> None:
) -> None:
self.model_id = model_id
if model_spec is not None:
self.specs = model_spec
@@ -179,13 +179,19 @@ class SamplingPipeline:
if model_path is None:
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
if config_path is None:
config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
config_path = (
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
)
self.config = str(config_path / self.specs.config)
self.ckpt = str(model_path / self.specs.ckpt)
if not os.path.exists(self.config):
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
raise ValueError(
f"Config {self.config} not found, check model spec or config_path"
)
if not os.path.exists(self.ckpt):
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
self.device = device
self.model = self._load_model(device=device, use_fp16=use_fp16)