mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-23 14:44:31 +01:00
Test model device manager and fix bugs
This commit is contained in:
@@ -6,7 +6,7 @@ from sgm.inference.helpers import (
|
||||
do_sample,
|
||||
do_img2img,
|
||||
DeviceModelManager,
|
||||
CudaModelManager,
|
||||
get_model_manager,
|
||||
Img2ImgDiscretizationWrapper,
|
||||
Txt2NoisyDiscretizationWrapper,
|
||||
)
|
||||
@@ -174,9 +174,7 @@ class SamplingPipeline:
|
||||
model_path: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
use_fp16: bool = True,
|
||||
device: Union[DeviceModelManager, str, torch.device] = CudaModelManager(
|
||||
device="cuda"
|
||||
),
|
||||
device: Optional[Union[DeviceModelManager, str, torch.device]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Sampling pipeline for generating images from a model.
|
||||
@@ -213,18 +211,16 @@ class SamplingPipeline:
|
||||
raise ValueError(
|
||||
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
|
||||
)
|
||||
|
||||
if isinstance(device, torch.device) or isinstance(device, str):
|
||||
if torch.device(device).type == "cuda":
|
||||
self.device_manager = CudaModelManager(device=device)
|
||||
else:
|
||||
self.device_manager = DeviceModelManager(device=device)
|
||||
if not isinstance(device, DeviceModelManager):
|
||||
self.device_manager = get_model_manager(device=device)
|
||||
else:
|
||||
self.device_manager = device
|
||||
|
||||
self.model = self._load_model(
|
||||
device_manager=self.device_manager, use_fp16=use_fp16
|
||||
)
|
||||
|
||||
|
||||
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
|
||||
config = OmegaConf.load(self.config)
|
||||
model = load_model_from_config(config, self.ckpt)
|
||||
|
||||
Reference in New Issue
Block a user