Test model device manager and fix bugs

This commit is contained in:
Stephan Auerhahn
2023-08-12 07:15:36 +00:00
parent fe4632034b
commit d4307bef5d
4 changed files with 72 additions and 16 deletions

View File

@@ -6,7 +6,7 @@ from sgm.inference.helpers import (
do_sample,
do_img2img,
DeviceModelManager,
CudaModelManager,
get_model_manager,
Img2ImgDiscretizationWrapper,
Txt2NoisyDiscretizationWrapper,
)
@@ -174,9 +174,7 @@ class SamplingPipeline:
model_path: Optional[str] = None,
config_path: Optional[str] = None,
use_fp16: bool = True,
device: Union[DeviceModelManager, str, torch.device] = CudaModelManager(
device="cuda"
),
device: Optional[Union[DeviceModelManager, str, torch.device]] = None
) -> None:
"""
Sampling pipeline for generating images from a model.
@@ -213,18 +211,16 @@ class SamplingPipeline:
raise ValueError(
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
if isinstance(device, torch.device) or isinstance(device, str):
if torch.device(device).type == "cuda":
self.device_manager = CudaModelManager(device=device)
else:
self.device_manager = DeviceModelManager(device=device)
if not isinstance(device, DeviceModelManager):
self.device_manager = get_model_manager(device=device)
else:
self.device_manager = device
self.model = self._load_model(
device_manager=self.device_manager, use_fp16=use_fp16
)
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
config = OmegaConf.load(self.config)
model = load_model_from_config(config, self.ckpt)