diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py
index b6814ec..5838741 100644
--- a/scripts/demo/streamlit_helpers.py
+++ b/scripts/demo/streamlit_helpers.py
@@ -20,9 +20,7 @@ from sgm.inference.api import (
     SamplingPipeline,
     Thresholder,
 )
-from sgm.inference.helpers import (
-    embed_watermark,
-)
+from sgm.inference.helpers import embed_watermark, CudaModelLoader
 
 
 @st.cache_resource()
@@ -35,10 +33,12 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
 
     if lowvram_mode:
         pipeline = SamplingPipeline(
-            model_spec=spec, use_fp16=True, device="cuda", swap_device="cpu"
+            model_spec=spec,
+            use_fp16=True,
+            model_loader=CudaModelLoader(device="cuda", swap_device="cpu"),
        )
     else:
-        pipeline = SamplingPipeline(model_spec=spec, use_fp16=True, device="cuda")
+        pipeline = SamplingPipeline(model_spec=spec, use_fp16=True)
     state["spec"] = spec
     state["model"] = pipeline
 
diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index 082ca18..0588a26 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -2,10 +2,11 @@ from dataclasses import dataclass, asdict
 from enum import Enum
 from omegaconf import OmegaConf
 import os
-import pathlib
 from sgm.inference.helpers import (
     do_sample,
     do_img2img,
+    BaseDeviceModelLoader,
+    CudaModelLoader,
     Img2ImgDiscretizationWrapper,
     Txt2NoisyDiscretizationWrapper,
 )
@@ -17,7 +18,7 @@ from sgm.modules.diffusionmodules.sampling import (
     DPMPP2MSampler,
     LinearMultistepSampler,
 )
-from sgm.util import load_model_from_config
+from sgm.util import load_model_from_config, get_configs_path, get_checkpoints_path
 import torch
 from typing import Optional, Dict, Any, Union
 
@@ -163,11 +164,10 @@ class SamplingPipeline:
         self,
         model_id: Optional[ModelArchitecture] = None,
         model_spec: Optional[SamplingSpec] = None,
-        model_path: Optional[Union[str, pathlib.Path]] = None,
-        config_path: Optional[Union[str, pathlib.Path]] = None,
-        device: Union[str, torch.device] = "cuda",
-        swap_device: Optional[Union[str, torch.device]] = None,
+        model_path: Optional[str] = None,
+        config_path: Optional[str] = None,
         use_fp16: bool = True,
+        model_loader: BaseDeviceModelLoader = CudaModelLoader(device="cuda"),
     ) -> None:
         """
         Sampling pipeline for generating images from a model.
@@ -176,9 +176,8 @@ class SamplingPipeline:
         @param model_spec: Model specification to use. If not specified, model_id must be specified.
         @param model_path: Path to model checkpoints folder.
         @param config_path: Path to model config folder.
-        @param device: Device to use for sampling.
-        @param swap_device: Device to swap models to when not in use.
         @param use_fp16: Whether to use fp16 for sampling.
+        @param model_loader: Model loader used to place the model on its device. Defaults to CudaModelLoader.
""" self.model_id = model_id @@ -192,11 +191,11 @@ class SamplingPipeline: raise ValueError("Either model_id or model_spec should be provided") if model_path is None: - model_path = self._resolve_default_path("checkpoints") + model_path = get_checkpoints_path() if config_path is None: - config_path = self._resolve_default_path("configs/inference") - self.config = str(pathlib.Path(config_path) / self.specs.config) - self.ckpt = str(pathlib.Path(model_path) / self.specs.ckpt) + config_path = get_configs_path() + self.config = os.path.join(config_path, "inference", self.specs.config) + self.ckpt = os.path.join(model_path, self.specs.ckpt) if not os.path.exists(self.config): raise ValueError( f"Config {self.config} not found, check model spec or config_path" @@ -210,19 +209,6 @@ class SamplingPipeline: load_device = device if swap_device is None else swap_device self.model = self._load_model(device=load_device, use_fp16=use_fp16) - def _resolve_default_path(self, suffix: str) -> pathlib.Path: - # Resolves a path relative to the root of the module or repo - repo_path = pathlib.Path(__file__).parent.parent.parent.resolve() / suffix - module_path = pathlib.Path(__file__).parent.parent.resolve() / suffix - path = module_path / suffix - if not os.path.exists(path): - path = repo_path / suffix - if not os.path.exists(path): - raise ValueError( - f"Default locations for {suffix} not found, please specify path" - ) - return pathlib.Path(path) - def _load_model(self, device="cuda", use_fp16=True): config = OmegaConf.load(self.config) model = load_model_from_config(config, self.ckpt) diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index 68409a2..bef2fb3 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -10,6 +10,7 @@ from einops import rearrange from imwatermark import WatermarkEncoder from omegaconf import ListConfig from torch import autocast +from abc import ABC, abstractmethod from sgm.util import append_dims @@ -353,35 +354,67 @@ def do_img2img( return samples -@contextlib.contextmanager -def swap_to_device( - model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str] -): +class BaseDeviceModelLoader(ABC): """ - Context manager that swaps a model or tensor to a device, and then swaps it back to its original device - when the context is exited. + Base class for device managers. Device managers are used to manage the device used for a model. """ - if isinstance(model, torch.Tensor): - original_device = model.device - else: - param = next(model.parameters(), None) - if param is not None: - original_device = param.device - else: - buf = next(model.buffers(), None) - if buf is not None: - original_device = buf.device - else: - # If device could not be found, do nothing - return - device = torch.device(device) - if device != original_device: - model.to(device) + @abstractmethod + def __init__(self, device: Union[torch.device, str]): + """ + Args: + device (Union[torch.device, str]): The device to use for the model. + """ + pass - yield + def load(self, model: torch.nn.Module): + """ + Loads a model to the device. + """ + pass - if device != original_device: - model.to(original_device) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + @contextlib.contextmanager + def use(self, model: torch.nn.Module): + """ + Context manager that ensures a model is on the correct device during use. 
+ """ + yield + + +class CudaModelLoader(BaseDeviceModelLoader): + """ + Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use. + """ + + def __init__( + self, + device: Union[torch.device, str] = "cuda", + swap_device: Union[torch.device, str] = None, + ): + """ + Args: + device (Union[torch.device, str]): The device to use for the model. + """ + self.device = torch.device(device) + self.swap_device = ( + torch.device(swap_device) if swap_device is not None else self.device + ) + + def load(self, model: Union[torch.nn.Module, torch.Tensor]): + """ + Loads a model to the device. + """ + model.to(self.swap_device) + + @contextlib.contextmanager + def use(self, model: Union[torch.nn.Module, torch.Tensor]): + """ + Context manager that ensures a model is on the correct device during use. + """ + if self.device != self.swap_device: + model.to(self.device) + yield + if self.device != self.swap_device: + model.to(self.swap_device) + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/sgm/util.py b/sgm/util.py index c5e68f4..1f96aeb 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -230,6 +230,24 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True): return model +def get_checkpoints_path() -> str: + """ + Get the `checkpoints` directory. + This could be in the root of the repository for a working copy, + or in the cwd for other use cases. + """ + this_dir = os.path.dirname(__file__) + candidates = ( + os.path.join(this_dir, "checkpoints"), + os.path.join(os.getcwd(), "checkpoints"), + ) + for candidate in candidates: + candidate = os.path.abspath(candidate) + if os.path.isdir(candidate): + return candidate + raise FileNotFoundError(f"Could not find SGM checkpoints in {candidates}") + + def get_configs_path() -> str: """ Get the `configs` directory.