From d4307bef5d75435b67b6907fbb21e49a5efdce6b Mon Sep 17 00:00:00 2001
From: Stephan Auerhahn
Date: Sat, 12 Aug 2023 07:15:36 +0000
Subject: [PATCH] Test model device manager and fix bugs
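
Device selection now goes through a get_model_manager() helper, the
pipeline's device argument is optional (None resolves to a CUDA
manager), and DeviceModelManager.load() places models on the swap
device so CudaModelManager can keep them offloaded until use. Callers
can opt in to CPU offloading explicitly, as exercised by the new
tests, e.g.:

    pipeline = SamplingPipeline(
        model_id=ModelArchitecture.SD_2_1,
        device=helpers.CudaModelManager(device="cuda", swap_device="cpu"),
    )

New tests cover default CUDA loading and CPU<->CUDA swapping.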
---
 pyproject.toml                       |  2 +-
 sgm/inference/api.py                 | 16 +++------
 sgm/inference/helpers.py             | 19 ++++++++---
 tests/inference/test_modelmanager.py | 51 ++++++++++++++++++++++++++++
 4 files changed, 72 insertions(+), 16 deletions(-)
 create mode 100644 tests/inference/test_modelmanager.py

diff --git a/pyproject.toml b/pyproject.toml
index 2cc5021..94ba68d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,5 +44,5 @@ dependencies = [
 test-inference = [
     "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
     "pip install -r requirements/pt2.txt",
-    "pytest -v tests/inference/test_inference.py {args}",
+    "pytest -v tests/inference {args}",
 ]
diff --git a/sgm/inference/api.py b/sgm/inference/api.py
index eccf129..afb1f72 100644
--- a/sgm/inference/api.py
+++ b/sgm/inference/api.py
@@ -6,7 +6,7 @@ from sgm.inference.helpers import (
     do_sample,
     do_img2img,
     DeviceModelManager,
-    CudaModelManager,
+    get_model_manager,
     Img2ImgDiscretizationWrapper,
     Txt2NoisyDiscretizationWrapper,
 )
@@ -174,9 +174,7 @@ class SamplingPipeline:
         model_path: Optional[str] = None,
         config_path: Optional[str] = None,
         use_fp16: bool = True,
-        device: Union[DeviceModelManager, str, torch.device] = CudaModelManager(
-            device="cuda"
-        ),
+        device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
     ) -> None:
         """
         Sampling pipeline for generating images from a model.
@@ -213,18 +211,16 @@ class SamplingPipeline:
             raise ValueError(
                 f"Checkpoint {self.ckpt} not found, check model spec or config_path"
             )
 
-
-        if isinstance(device, torch.device) or isinstance(device, str):
-            if torch.device(device).type == "cuda":
-                self.device_manager = CudaModelManager(device=device)
-            else:
-                self.device_manager = DeviceModelManager(device=device)
+        if not isinstance(device, DeviceModelManager):
+            self.device_manager = get_model_manager(device=device)
         else:
             self.device_manager = device
+
         self.model = self._load_model(
             device_manager=self.device_manager, use_fp16=use_fp16
         )
+
     def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index f86eda6..addefe3 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
@@ -67,7 +67,7 @@ class DeviceModelManager(object):
         self,
         device: Union[torch.device, str],
         swap_device: Optional[Union[torch.device, str]] = None,
-    ):
+    ) -> None:
         """
         Args:
             device (Union[torch.device, str]): The device to use for the model.
@@ -77,11 +77,11 @@
             torch.device(swap_device) if swap_device is not None else self.device
         )
 
-    def load(self, model: torch.nn.Module):
+    def load(self, model: torch.nn.Module) -> None:
         """
-        Loads a model to the device.
+        Loads a model to the (swap) device.
         """
-        return model.to(self.device)
+        model.to(self.swap_device)
 
     def autocast(self):
         """
@@ -109,7 +109,7 @@
 class CudaModelManager(DeviceModelManager):
     """
     Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
-    """ 
+    """
 
     @contextlib.contextmanager
     def use(self, model: Union[torch.nn.Module, torch.Tensor]):
@@ -141,6 +141,15 @@ def perform_save_locally(save_path, samples):
         base_count += 1
 
 
+def get_model_manager(
+    device: Optional[Union[str, torch.device]] = None
+) -> DeviceModelManager:
+    """Returns a device manager for the given device, defaulting to CUDA."""
+    device = torch.device(device if device is not None else "cuda")
+    if device.type == "cuda":
+        return CudaModelManager(device=device)
+    return DeviceModelManager(device=device)
+
 class Img2ImgDiscretizationWrapper:
     """
     wraps a discretizer, and prunes the sigmas
diff --git a/tests/inference/test_modelmanager.py b/tests/inference/test_modelmanager.py
new file mode 100644
index 0000000..fa96d70
--- /dev/null
+++ b/tests/inference/test_modelmanager.py
@@ -0,0 +1,51 @@
+import numpy
+from PIL import Image
+import pytest
+from pytest import fixture
+import torch
+from typing import Tuple, Optional
+
+from sgm.inference.api import (
+    model_specs,
+    SamplingParams,
+    SamplingPipeline,
+    Sampler,
+    ModelArchitecture,
+)
+import sgm.inference.helpers as helpers
+
+def get_torch_device(model: torch.nn.Module) -> torch.device:
+    param = next(model.parameters(), None)
+    if param is not None:
+        return param.device
+    else:
+        buf = next(model.buffers(), None)
+        if buf is not None:
+            return buf.device
+        else:
+            raise TypeError("Could not determine device of input model")
+
+
+@pytest.mark.inference
+def test_default_loading():
+    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
+    assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    with pipeline.device_manager.use(pipeline.model.model):
+        assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.model).type == "cuda"
+    with pipeline.device_manager.use(pipeline.model.conditioner):
+        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+
+@pytest.mark.inference
+def test_model_swapping():
+    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu"))
+    assert get_torch_device(pipeline.model.model).type == "cpu"
+    assert get_torch_device(pipeline.model.conditioner).type == "cpu"
+    with pipeline.device_manager.use(pipeline.model.model):
+        assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.model).type == "cpu"
+    with pipeline.device_manager.use(pipeline.model.conditioner):
+        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cpu"
\ No newline at end of file