mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
Allow loading custom models and improve path logic
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from dataclasses import dataclass, asdict
|
||||
from enum import Enum
|
||||
from omegaconf import OmegaConf
|
||||
import os
|
||||
import pathlib
|
||||
from sgm.inference.helpers import (
|
||||
do_sample,
|
||||
@@ -158,18 +159,33 @@ model_specs = {
|
||||
class SamplingPipeline:
|
||||
def __init__(
|
||||
self,
|
||||
model_id: ModelArchitecture,
|
||||
model_path="checkpoints",
|
||||
config_path="configs/inference",
|
||||
model_id: Optional[ModelArchitecture] = None,
|
||||
model_spec: Optional[SamplingSpec] = None,
|
||||
model_path=None,
|
||||
config_path=None,
|
||||
device="cuda",
|
||||
use_fp16=True,
|
||||
) -> None:
|
||||
if model_id not in model_specs:
|
||||
raise ValueError(f"Model {model_id} not supported")
|
||||
) -> None:
|
||||
self.model_id = model_id
|
||||
self.specs = model_specs[self.model_id]
|
||||
self.config = str(pathlib.Path(config_path, self.specs.config))
|
||||
self.ckpt = str(pathlib.Path(model_path, self.specs.ckpt))
|
||||
if model_spec is not None:
|
||||
self.specs = model_spec
|
||||
elif model_id is not None:
|
||||
if model_id not in model_specs:
|
||||
raise ValueError(f"Model {model_id} not supported")
|
||||
self.specs = model_specs[model_id]
|
||||
else:
|
||||
raise ValueError("Either model_id or model_spec should be provided")
|
||||
|
||||
if model_path is None:
|
||||
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
|
||||
if config_path is None:
|
||||
config_path = pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
|
||||
self.config = str(config_path / self.specs.config)
|
||||
self.ckpt = str(model_path / self.specs.ckpt)
|
||||
if not os.path.exists(self.config):
|
||||
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
|
||||
if not os.path.exists(self.ckpt):
|
||||
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
|
||||
self.device = device
|
||||
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user