Mirror of https://github.com/Stability-AI/generative-models.git (synced 2026-02-03 12:54:27 +01:00)
path helper & model swapping rewrite
@@ -20,9 +20,7 @@ from sgm.inference.api import (
     SamplingPipeline,
     Thresholder,
 )
-from sgm.inference.helpers import (
-    embed_watermark,
-)
+from sgm.inference.helpers import embed_watermark, CudaModelLoader


 @st.cache_resource()
@@ -35,10 +33,12 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A

     if lowvram_mode:
         pipeline = SamplingPipeline(
-            model_spec=spec, use_fp16=True, device="cuda", swap_device="cpu"
+            model_spec=spec,
+            use_fp16=True,
+            model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
         )
     else:
-        pipeline = SamplingPipeline(model_spec=spec, use_fp16=True, device="cuda")
+        pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)

     state["spec"] = spec
     state["model"] = pipeline
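For context, a minimal sketch of the two call styles this hunk produces (assuming the `CudaModelLoader` import shown above and a `SamplingSpec` bound to `spec`; not part of the commit):

```python
# Low-VRAM branch: device placement and CPU offload now live in the loader object.
pipeline = SamplingPipeline(
    model_spec=spec,
    use_fp16=True,
    model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
)

# Default branch: model_loader defaults to CudaModelLoader(device="cuda"),
# so the call site no longer has to spell out a device at all.
pipeline = SamplingPipeline(model_spec=spec)
```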
@@ -2,10 +2,11 @@ from dataclasses import dataclass, asdict
 from enum import Enum
 from omegaconf import OmegaConf
 import os
-import pathlib
 from sgm.inference.helpers import (
     do_sample,
     do_img2img,
+    BaseDeviceModelLoader,
+    CudaModelLoader,
     Img2ImgDiscretizationWrapper,
     Txt2NoisyDiscretizationWrapper,
 )
@@ -17,7 +18,7 @@ from sgm.modules.diffusionmodules.sampling import (
     DPMPP2MSampler,
     LinearMultistepSampler,
 )
-from sgm.util import load_model_from_config
+from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path
 import torch
 from typing import Optional, Dict, Any, Union

@@ -163,11 +164,10 @@ class SamplingPipeline:
         self,
         model_id: Optional[ModelArchitecture] = None,
         model_spec: Optional[SamplingSpec] = None,
-        model_path: Optional[Union[str, pathlib.Path]] = None,
-        config_path: Optional[Union[str, pathlib.Path]] = None,
-        device: Union[str, torch.device] = "cuda",
-        swap_device: Optional[Union[str, torch.device]] = None,
+        model_path: Optional[str] = None,
+        config_path: Optional[str] = None,
         use_fp16: bool = True,
+        model_loader: BaseDeviceModelLoader = CudaModelLoader(device="cuda"),
     ) -> None:
         """
         Sampling pipeline for generating images from a model.
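For callers of the public constructor, the old placement kwargs fold into the loader object; a hedged before/after sketch (the enum member below is illustrative, not taken from this diff):

```python
# Before this commit:
#   pipeline = SamplingPipeline(model_id=model_id, device="cuda", swap_device="cpu")
# After: the same placement is requested through a loader instance.
pipeline = SamplingPipeline(
    model_id=ModelArchitecture.SDXL_V1_BASE,  # hypothetical member, for illustration
    model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
)
```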
@@ -176,9 +176,8 @@ class SamplingPipeline:
         @param model_spec: Model specification to use. If not specified, model_id must be specified.
         @param model_path: Path to model checkpoints folder.
         @param config_path: Path to model config folder.
-        @param device: Device to use for sampling.
-        @param swap_device: Device to swap models to when not in use.
         @param use_fp16: Whether to use fp16 for sampling.
+        @param model_loader: Model loader class to use. Defaults to CudaModelLoader.
         """

         self.model_id = model_id
@@ -192,11 +191,11 @@
             raise ValueError("Either model_id or model_spec should be provided")

         if model_path is None:
-            model_path = self._resolve_default_path("checkpoints")
+            model_path = get_checkpoints_path()
         if config_path is None:
-            config_path = self._resolve_default_path("configs/inference")
-        self.config = str(pathlib.Path(config_path) / self.specs.config)
-        self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt)
+            config_path = get_configs_path()
+        self.config = os.path.join(config_path, "inference", self.specs.config)
+        self.ckpt = os.path.join(model_path, self.specs.ckpt)
         if not os.path.exists(self.config):
             raise ValueError(
                 f"Config {self.config} not found, check model spec or config_path"
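Both resolved locations are now plain strings; a small sketch of what the new lines compute (the spec filenames here are hypothetical):

```python
import os

# Hypothetical spec values, for illustration only.
spec_config = "sd_xl_base.yaml"
spec_ckpt = "sd_xl_base_1.0.safetensors"

config = os.path.join(get_configs_path(), "inference", spec_config)
ckpt = os.path.join(get_checkpoints_path(), spec_ckpt)
# e.g. <repo>/configs/inference/sd_xl_base.yaml and <repo>/checkpoints/sd_xl_base_1.0.safetensors
```

Note that the "inference" segment moved out of the helper call and into the join, so a user-supplied config_path now points at the configs root rather than at configs/inference.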
@@ -210,19 +209,6 @@
         load_device = device if swap_device is None else swap_device
         self.model = self._load_model(device=load_device, use_fp16=use_fp16)

-    def _resolve_default_path(self, suffix: str) -> pathlib.Path:
-        # Resolves a path relative to the root of the module or repo
-        repo_path = pathlib.Path(__file__).parent.parent.parent.resolve() / suffix
-        module_path = pathlib.Path(__file__).parent.parent.resolve() / suffix
-        path = module_path / suffix
-        if not os.path.exists(path):
-            path = repo_path / suffix
-        if not os.path.exists(path):
-            raise ValueError(
-                f"Default locations for {suffix} not found, please specify path"
-            )
-        return pathlib.Path(path)
-
     def _load_model(self, device="cuda", use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
@@ -10,6 +10,7 @@ from einops import rearrange
 from imwatermark import WatermarkEncoder
 from omegaconf import ListConfig
 from torch import autocast
+from abc import ABC, abstractmethod

 from sgm.util import append_dims

@@ -353,35 +354,67 @@ def do_img2img(
     return samples


-@contextlib.contextmanager
-def swap_to_device(
-    model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
-):
-    """
-    Context manager that swaps a model or tensor to a device, and then swaps it back to its original device
-    when the context is exited.
-    """
-    if isinstance(model, torch.Tensor):
-        original_device = model.device
-    else:
-        param = next(model.parameters(), None)
-        if param is not None:
-            original_device = param.device
-        else:
-            buf = next(model.buffers(), None)
-            if buf is not None:
-                original_device = buf.device
-            else:
-                # If device could not be found, do nothing
-                return
-    device = torch.device(device)
-
-    if device != original_device:
-        model.to(device)
-
-    yield
-
-    if device != original_device:
-        model.to(original_device)
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+class BaseDeviceModelLoader(ABC):
+    """
+    Base class for device managers. Device managers are used to manage the device used for a model.
+    """
+
+    @abstractmethod
+    def __init__(self, device: Union[torch.device, str]):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        pass
+
+    def load(self, model: torch.nn.Module):
+        """
+        Loads a model to the device.
+        """
+        pass
+
+    @contextlib.contextmanager
+    def use(self, model: torch.nn.Module):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        """
+        yield
+
+
+class CudaModelLoader(BaseDeviceModelLoader):
+    """
+    Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
+    """
+
+    def __init__(
+        self,
+        device: Union[torch.device, str] = "cuda",
+        swap_device: Union[torch.device, str] = None,
+    ):
+        """
+        Args:
+            device (Union[torch.device, str]): The device to use for the model.
+        """
+        self.device = torch.device(device)
+        self.swap_device = (
+            torch.device(swap_device) if swap_device is not None else self.device
+        )
+
+    def load(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Loads a model to the device.
+        """
+        model.to(self.swap_device)
+
+    @contextlib.contextmanager
+    def use(self, model: Union[torch.nn.Module, torch.Tensor]):
+        """
+        Context manager that ensures a model is on the correct device during use.
+        """
+        if self.device != self.swap_device:
+            model.to(self.device)
+        yield
+        if self.device != self.swap_device:
+            model.to(self.swap_device)
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
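Taken together, load() decides where the weights rest between calls and use() brackets each forward pass; a minimal sketch of the contract (requires a CUDA device; the PinnedModelLoader subclass is hypothetical, included only to show the extension point):

```python
import contextlib
import torch

loader = CudaModelLoader(device="cuda", swap_device="cpu")

model = torch.nn.Linear(4, 4)
loader.load(model)  # parameters rest on the swap device (CPU) until needed

with loader.use(model):  # moved to CUDA only for the duration of the block
    y = model(torch.randn(1, 4, device="cuda"))
# on exit: weights return to CPU and the CUDA allocator cache is emptied


# Hypothetical fixed-device loader, showing what a BaseDeviceModelLoader
# subclass has to provide when no swapping is wanted:
class PinnedModelLoader(BaseDeviceModelLoader):
    def __init__(self, device: str = "cuda"):
        self.device = torch.device(device)

    def load(self, model: torch.nn.Module):
        model.to(self.device)

    @contextlib.contextmanager
    def use(self, model: torch.nn.Module):
        yield  # nothing to move; the model already lives on self.device
```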
sgm/util.py (+18 lines)
@@ -230,6 +230,24 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
     return model


+def get_checkpoints_path() -> str:
+    """
+    Get the `checkpoints` directory.
+    This could be in the root of the repository for a working copy,
+    or in the cwd for other use cases.
+    """
+    this_dir = os.path.dirname(__file__)
+    candidates = (
+        os.path.join(this_dir, "checkpoints"),
+        os.path.join(os.getcwd(), "checkpoints"),
+    )
+    for candidate in candidates:
+        candidate = os.path.abspath(candidate)
+        if os.path.isdir(candidate):
+            return candidate
+    raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}")
+
+
 def get_configs_path() -> str:
     """
     Get the `configs` directory.
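A short usage sketch of the new helper (its existing get_configs_path counterpart behaves analogously): candidate directories are checked in order, and a FileNotFoundError is raised rather than returning a guess, so callers can fall back to an explicit path:

```python
from sgm.util import get_checkpoints_path

try:
    ckpt_dir = get_checkpoints_path()  # directory containing sgm/util.py first, then the cwd
except FileNotFoundError:
    ckpt_dir = "/data/checkpoints"  # hypothetical explicit fallback
```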