From 44943df4f218cda2265a0f7bdf4db4ffa501a504 Mon Sep 17 00:00:00 2001 From: Stephan Auerhahn Date: Thu, 3 Aug 2023 23:59:42 +0000 Subject: [PATCH] Allow loading custom models and improve path logic --- sgm/inference/api.py | 34 +++++++++++++++++++++++++--------- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/sgm/inference/api.py b/sgm/inference/api.py index ec17dfe..182e8cf 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -1,6 +1,7 @@ from dataclasses import dataclass, asdict from enum import Enum from omegaconf import OmegaConf +import os import pathlib from sgm.inference.helpers import ( do_sample, @@ -158,18 +159,33 @@ model_specs = { class SamplingPipeline: def __init__( self, - model_id: ModelArchitecture, - model_path="checkpoints", - config_path="configs/inference", + model_id: Optional[ModelArchitecture] = None, + model_spec: Optional[SamplingSpec] = None, + model_path=None, + config_path=None, device="cuda", use_fp16=True, - ) -> None: - if model_id not in model_specs: - raise ValueError(f"Model {model_id} not supported") + ) -> None: self.model_id = model_id - self.specs = model_specs[self.model_id] - self.config = str(pathlib.Path(config_path, self.specs.config)) - self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt)) + if model_spec is not None: + self.specs = model_spec + elif model_id is not None: + if model_id not in model_specs: + raise ValueError(f"Model {model_id} not supported") + self.specs = model_specs[model_id] + else: + raise ValueError("Either model_id or model_spec should be provided") + + if model_path is None: + model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints" + if config_path is None: + config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference" + self.config = str(config_path / self.specs.config) + self.ckpt = str(model_path / self.specs.ckpt) + if not os.path.exists(self.config): + raise ValueError(f"Config {self.config} not found, check model spec or config_path") + if not os.path.exists(self.ckpt): + raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path") self.device = device self.model = self._load_model(device=device, use_fp16=use_fp16)