mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
extract path resolution method, fix/improve device swapping support
This commit is contained in:
@@ -33,11 +33,12 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
|
||||
config = spec.config
|
||||
ckpt = spec.ckpt
|
||||
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec,
|
||||
use_fp16=lowvram_mode,
|
||||
device="cpu" if lowvram_mode else "cuda",
|
||||
)
|
||||
if lowvram_mode:
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec, use_fp16=True, device="cuda", swap_device="cpu"
|
||||
)
|
||||
else:
|
||||
pipeline = SamplingPipeline(model_spec=spec, use_fp16=True, device="cuda")
|
||||
|
||||
state["spec"] = spec
|
||||
state["model"] = pipeline
|
||||
|
||||
@@ -19,7 +19,7 @@ from sgm.modules.diffusionmodules.sampling import (
|
||||
)
|
||||
from sgm.util import load_model_from_config
|
||||
import torch
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, Union
|
||||
|
||||
|
||||
class ModelArchitecture(str, Enum):
|
||||
@@ -163,11 +163,24 @@ class SamplingPipeline:
|
||||
self,
|
||||
model_id: Optional[ModelArchitecture] = None,
|
||||
model_spec: Optional[SamplingSpec] = None,
|
||||
model_path=None,
|
||||
config_path=None,
|
||||
device="cuda",
|
||||
use_fp16=True,
|
||||
model_path: Optional[Union[str, pathlib.Path]] = None,
|
||||
config_path: Optional[Union[str, pathlib.Path]] = None,
|
||||
device: Union[str, torch.Device] = "cuda",
|
||||
swap_device: Optional[Union[str, torch.Device]] = None,
|
||||
use_fp16: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Sampling pipeline for generating images from a model.
|
||||
|
||||
@param model_id: Model architecture to use. If not specified, model_spec must be specified.
|
||||
@param model_spec: Model specification to use. If not specified, model_id must be specified.
|
||||
@param model_path: Path to model checkpoints folder.
|
||||
@param config_path: Path to model config folder.
|
||||
@param device: Device to use for sampling.
|
||||
@param swap_device: Device to swap models to when not in use.
|
||||
@param use_fp16: Whether to use fp16 for sampling.
|
||||
"""
|
||||
|
||||
self.model_id = model_id
|
||||
if model_spec is not None:
|
||||
self.specs = model_spec
|
||||
@@ -179,23 +192,9 @@ class SamplingPipeline:
|
||||
raise ValueError("Either model_id or model_spec should be provided")
|
||||
|
||||
if model_path is None:
|
||||
model_path = pathlib.Path(__file__).parent.parent.resolve() / "checkpoints"
|
||||
if not os.path.exists(model_path):
|
||||
# This supports development installs where checkpoints is root level of the repo
|
||||
model_path = (
|
||||
pathlib.Path(__file__).parent.parent.parent.resolve()
|
||||
/ "checkpoints"
|
||||
)
|
||||
model_path = self._resolve_default_path("checkpoints")
|
||||
if config_path is None:
|
||||
config_path = (
|
||||
pathlib.Path(__file__).parent.parent.resolve() / "configs/inference"
|
||||
)
|
||||
if not os.path.exists(config_path):
|
||||
# This supports development installs where configs is root level of the repo
|
||||
config_path = (
|
||||
pathlib.Path(__file__).parent.parent.parent.resolve()
|
||||
/ "configs/inference"
|
||||
)
|
||||
config_path = self._resolve_default_path("configs/inference")
|
||||
self.config = str(pathlib.Path(config_path) / self.specs.config)
|
||||
self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt)
|
||||
if not os.path.exists(self.config):
|
||||
@@ -207,7 +206,22 @@ class SamplingPipeline:
|
||||
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
|
||||
)
|
||||
self.device = device
|
||||
self.model = self._load_model(device=device, use_fp16=use_fp16)
|
||||
self.swap_device = swap_device
|
||||
load_device = device if swap_device is None else swap_device
|
||||
self.model = self._load_model(device=load_device, use_fp16=use_fp16)
|
||||
|
||||
def _resolve_default_path(self, suffix: str) -> pathlib.Path:
    """
    Resolve a default resource directory relative to the module or repo root.

    Two candidate locations are probed, in order:

    1. ``<module root>/<suffix>``, where the module root is two directory
       levels above this file.
    2. ``<repo root>/<suffix>``, where the repo root is three directory
       levels above this file.

    @param suffix: Relative path to look for, e.g. "checkpoints".
    @return: The first candidate path that exists.
    @raise ValueError: If neither candidate location exists.
    """
    this_file = pathlib.Path(__file__)
    module_root = this_file.parent.parent.resolve()
    repo_root = this_file.parent.parent.parent.resolve()

    # Append the suffix exactly once per root. The previous version joined
    # it both when building the roots and again when probing, so it looked
    # for "<root>/<suffix>/<suffix>" and never found the real directory.
    for root in (module_root, repo_root):
        candidate = root / suffix
        if candidate.exists():
            return candidate

    raise ValueError(
        f"Default locations for {suffix} not found, please specify path"
    )
|
||||
|
||||
def _load_model(self, device="cuda", use_fp16=True):
|
||||
config = OmegaConf.load(self.config)
|
||||
@@ -256,6 +270,7 @@ class SamplingPipeline:
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def image_to_image(
|
||||
@@ -293,6 +308,7 @@ class SamplingPipeline:
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def wrap_discretization(
|
||||
@@ -361,6 +377,7 @@ class SamplingPipeline:
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
add_noise=add_noise,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user