mirror of https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
Test model device manager and fix bugs
pyproject.toml

@@ -44,5 +44,5 @@ dependencies = [
 test-inference = [
     "pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118",
     "pip install -r requirements/pt2.txt",
-    "pytest -v tests/inference/test_inference.py {args}",
+    "pytest -v tests/inference {args}",
 ]
sgm/inference/api.py

@@ -6,7 +6,7 @@ from sgm.inference.helpers import (
     do_sample,
     do_img2img,
     DeviceModelManager,
-    CudaModelManager,
+    get_model_manager,
     Img2ImgDiscretizationWrapper,
     Txt2NoisyDiscretizationWrapper,
 )
@@ -174,9 +174,7 @@ class SamplingPipeline:
         model_path: Optional[str] = None,
         config_path: Optional[str] = None,
         use_fp16: bool = True,
-        device: Union[DeviceModelManager, str, torch.device] = CudaModelManager(
-            device="cuda"
-        ),
+        device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
     ) -> None:
         """
         Sampling pipeline for generating images from a model.
@@ -213,18 +211,16 @@ class SamplingPipeline:
             raise ValueError(
                 f"Checkpoint {self.ckpt} not found, check model spec or config_path"
             )

-        if isinstance(device, torch.device) or isinstance(device, str):
-            if torch.device(device).type == "cuda":
-                self.device_manager = CudaModelManager(device=device)
-            else:
-                self.device_manager = DeviceModelManager(device=device)
+        if not isinstance(device, DeviceModelManager):
+            self.device_manager = get_model_manager(device=device)
         else:
             self.device_manager = device

         self.model = self._load_model(
             device_manager=self.device_manager, use_fp16=use_fp16
         )

     def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
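For illustration, here are the two calling conventions the new constructor accepts. This is a sketch based on the tests added in this commit (see test_modelmanager.py below), not part of the diff itself:

    # Sketch (not commit content): the two accepted forms of `device`.
    from sgm.inference.api import SamplingPipeline, ModelArchitecture
    import sgm.inference.helpers as helpers

    # A plain string or torch.device is resolved through get_model_manager(),
    # so "cuda" yields a CudaModelManager.
    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device="cuda")

    # A preconfigured DeviceModelManager is used as-is; this form appears
    # verbatim in test_model_swapping below.
    pipeline = SamplingPipeline(
        model_id=ModelArchitecture.SD_2_1,
        device=helpers.CudaModelManager(device="cuda", swap_device="cpu"),
    )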
sgm/inference/helpers.py

@@ -67,7 +67,7 @@ class DeviceModelManager(object):
         self,
         device: Union[torch.device, str],
         swap_device: Optional[Union[torch.device, str]] = None,
-    ):
+    ) -> None:
         """
         Args:
             device (Union[torch.device, str]): The device to use for the model.
@@ -77,11 +77,11 @@ class DeviceModelManager(object):
             torch.device(swap_device) if swap_device is not None else self.device
         )

-    def load(self, model: torch.nn.Module):
+    def load(self, model: torch.nn.Module) -> None:
         """
-        Loads a model to the device.
+        Loads a model to the (swap) device.
         """
-        return model.to(self.device)
+        model.to(self.swap_device)

     def autocast(self):
         """
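Note the behavioral change in load: the model is now parked on swap_device (which falls back to device when no swap device is given, per the __init__ shown above), and the method no longer returns the module. A minimal sketch of the new contract, assuming only what this hunk shows:

    # Sketch (not commit content): load() now moves the module in place.
    import torch
    from sgm.inference.helpers import DeviceModelManager

    manager = DeviceModelManager(device="cpu")  # swap_device defaults to device
    model = torch.nn.Linear(4, 4)
    manager.load(model)  # returns None; nn.Module.to() moves parameters in place
    assert next(model.parameters()).device.type == "cpu"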
@@ -109,7 +109,7 @@ class DeviceModelManager(object):
 class CudaModelManager(DeviceModelManager):
     """
     Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
     """

     @contextlib.contextmanager
     def use(self, model: Union[torch.nn.Module, torch.Tensor]):
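The body of use is outside this diff, but the tests below pin down its contract: inside the with block the module lives on device, and on exit it sits on swap_device again. A hedged sketch of that contract, restricted to nn.Module (Tensor.to() is not in-place, so the real implementation presumably handles tensors differently):

    # Sketch only; the actual CudaModelManager.use() may differ.
    import contextlib
    import torch

    class SwapSketch:
        def __init__(self, device: str = "cuda", swap_device: str = "cpu") -> None:
            self.device = torch.device(device)
            self.swap_device = torch.device(swap_device)

        @contextlib.contextmanager
        def use(self, model: torch.nn.Module):
            model.to(self.device)  # bring onto the compute device
            try:
                yield model
            finally:
                model.to(self.swap_device)  # park back on the swap device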
@@ -141,6 +141,15 @@ def perform_save_locally(save_path, samples):
         base_count += 1


+def get_model_manager(device: Union[str, torch.device]) -> DeviceModelManager:
+    if isinstance(device, torch.device) or isinstance(device, str):
+        if torch.device(device).type == "cuda":
+            return CudaModelManager(device=device)
+        else:
+            return DeviceModelManager(device=device)
+    else:
+        return device
+

 class Img2ImgDiscretizationWrapper:
     """
     wraps a discretizer, and prunes the sigmas
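For reference, a sketch (not commit content) of how the resolver behaves. As written, anything that is neither a str nor a torch.device, including the new SamplingPipeline default of None, falls through to the final else and is returned unchanged:

    import torch
    import sgm.inference.helpers as helpers

    assert isinstance(helpers.get_model_manager("cuda"), helpers.CudaModelManager)
    assert isinstance(helpers.get_model_manager(torch.device("cpu")), helpers.DeviceModelManager)
    assert helpers.get_model_manager(None) is None  # final else returns the input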
tests/inference/test_modelmanager.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import numpy
+from PIL import Image
+import pytest
+from pytest import fixture
+import torch
+from typing import Tuple, Optional
+
+from sgm.inference.api import (
+    model_specs,
+    SamplingParams,
+    SamplingPipeline,
+    Sampler,
+    ModelArchitecture,
+)
+import sgm.inference.helpers as helpers
+
+def get_torch_device(model: torch.nn.Module) -> torch.device:
+    param = next(model.parameters(), None)
+    if param is not None:
+        return param.device
+    else:
+        buf = next(model.buffers(), None)
+        if buf is not None:
+            return buf.device
+        else:
+            raise TypeError("Could not determine device of input model")
+
+
+@pytest.mark.inference
+def test_default_loading():
+    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1)
+    assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    with pipeline.device_manager.use(pipeline.model.model):
+        assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.model).type == "cuda"
+    with pipeline.device_manager.use(pipeline.model.conditioner):
+        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+
+@pytest.mark.inference
+def test_model_swapping():
+    pipeline = SamplingPipeline(model_id=ModelArchitecture.SD_2_1, device=helpers.CudaModelManager(device="cuda", swap_device="cpu"))
+    assert get_torch_device(pipeline.model.model).type == "cpu"
+    assert get_torch_device(pipeline.model.conditioner).type == "cpu"
+    with pipeline.device_manager.use(pipeline.model.model):
+        assert get_torch_device(pipeline.model.model).type == "cuda"
+    assert get_torch_device(pipeline.model.model).type == "cpu"
+    with pipeline.device_manager.use(pipeline.model.conditioner):
+        assert get_torch_device(pipeline.model.conditioner).type == "cuda"
+    assert get_torch_device(pipeline.model.conditioner).type == "cpu"