Test model device manager and fix bugs

This commit is contained in:
Stephan Auerhahn
2023-08-12 07:15:36 +00:00
parent fe4632034b
commit d4307bef5d
4 changed files with 72 additions and 16 deletions

View File

@@ -44,5 +44,5 @@ dependencies = [
test-inference = [
"pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
"pip install -r requirements/pt2.txt",
"pytest -v tests/inference/test_inference.py {args}",
"pytest -v tests/inference {args}",
]

View File

@@ -6,7 +6,7 @@ from sgm.inference.helpers import (
do_sample,
do_img2img,
DeviceModelManager,
CudaModelManager,
get_model_manager,
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
)
@@ -174,9 +174,7 @@ class SamplingPipeline:
model_path: Optional[str] = None,
config_path: Optional[str] = None,
use_fp16: bool = True,
device: Union[DeviceModelManager, str, torch.device] = CudaModelManager(
device="cuda"
),
device: Optional[Union[DeviceModelManager, str, torch.device]] = None
) -> None:
"""
Sampling pipeline for generating images from a model.
@@ -213,18 +211,16 @@ class SamplingPipeline:
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
if isinstance(device, torch.device) or isinstance(device, str):
if torch.device(device).type == "cuda":
self.device_manager = CudaModelManager(device=device)
else:
self.device_manager = DeviceModelManager(device=device)
if not isinstance(device, DeviceModelManager):
self.device_manager = get_model_manager(device=device)
else:
self.device_manager = device
self.model = self._load_model(
device_manager=self.device_manager, use_fp16=use_fp16
)
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
config = OmegaConf.load(self.config)
model = load_model_from_config(config, self.ckpt)

View File

@@ -67,7 +67,7 @@ class DeviceModelManager(object):
self,
device: Union[torch.device, str],
swap_device: Optional[Union[torch.device, str]] = None,
):
) -> None:
"""
Args:
device (Union[torch.device, str]): The device to use for the model.
@@ -77,11 +77,11 @@ class DeviceModelManager(object):
torch.device(swap_device) if swap_device is not None else self.device
)
def load(self, model: torch.nn.Module):
def load(self, model: torch.nn.Module) -> None:
"""
Loads a model to the device.
Loads a model to the (swap) device.
"""
return model.to(self.device)
model.to(self.swap_device)
def autocast(self):
"""
@@ -109,7 +109,7 @@ class DeviceModelManager(object):
class CudaModelManager(DeviceModelManager):
"""
Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
"""
"""
@contextlib.contextmanager
def use(self, model: Union[torch.nn.Module, torch.Tensor]):
@@ -141,6 +141,15 @@ def perform_save_locally(save_path, samples):
base_count += 1
def get_model_manager(device: Union[str,torch.device]) -> DeviceModelManager:
if isinstance(device, torch.device) or isinstance(device, str):
if torch.device(device).type == "cuda":
return CudaModelManager(device=device)
else:
return DeviceModelManager(device=device)
else:
return device
class Img2ImgDiscretizationWrapper:
"""
wraps a discretizer, and prunes the sigmas

View File

@@ -0,0 +1,51 @@
import numpy
from PIL import Image
import pytest
from pytest import fixture
import torch
from typing import Tuple, Optional
from sgm.inference.api import (
model_specs,
SamplingParams,
SamplingPipeline,
Sampler,
ModelArchitecture,
)
import sgm.inference.helpers as helpers
def get_torch_device(model: torch.nn.Module) -> torch.device:
param = next(model.parameters(), None)
if param is not None:
return param.device
else:
buf = next(model.buffers(), None)
if buf is not None:
return buf.device
else:
raise TypeError("Could not determine device of input model")
@pytest.mark.inference
def test_default_loading():
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
with pipeline.device_manager.use(pipeline.model.model):
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.model).type == "cuda"
with pipeline.device_manager.use(pipeline.model.conditioner):
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
@pytest.mark.inference
def test_model_swapping():
pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu"))
assert get_torch_device(pipeline.model.model).type == "cpu"
assert get_torch_device(pipeline.model.conditioner).type == "cpu"
with pipeline.device_manager.use(pipeline.model.model):
assert get_torch_device(pipeline.model.model).type == "cuda"
assert get_torch_device(pipeline.model.model).type == "cpu"
with pipeline.device_manager.use(pipeline.model.conditioner):
assert get_torch_device(pipeline.model.conditioner).type == "cuda"
assert get_torch_device(pipeline.model.conditioner).type == "cpu"