simplify device_manager usage

This commit is contained in:
Stephan Auerhahn
2023-08-10 13:05:30 -07:00
parent 88395261d8
commit 3816aaa639
2 changed files with 19 additions and 4 deletions

View File

@@ -36,7 +36,7 @@ def init_st(
pipeline = SamplingPipeline(
model_spec=spec,
use_fp16=True,
device_manager=CudaModelManager(device="cuda", swap_device="cpu"),
device=CudaModelManager(device="cuda", swap_device="cpu"),
)
else:
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)

View File

@@ -57,6 +57,13 @@ class Thresholder(str, Enum):
@dataclass
class SamplingParams:
"""
Parameters for sampling.
The defaults here are derived from user preference testing.
They will be subject to change in the future, likely pulled
from model specs instead of global defaults.
"""
width: int = 1024
height: int = 1024
steps: int = 40
@@ -167,7 +174,9 @@ class SamplingPipeline:
model_path: Optional[str] = None,
config_path: Optional[str] = None,
use_fp16: bool = True,
device_manager: DeviceModelManager = CudaModelManager(device="cuda"),
device: Union[DeviceModelManager, str, torch.device] = CudaModelManager(
device="cuda"
),
) -> None:
"""
Sampling pipeline for generating images from a model.
@@ -177,7 +186,7 @@ class SamplingPipeline:
@param model_path: Path to model checkpoints folder.
@param config_path: Path to model config folder.
@param use_fp16: Whether to use fp16 for sampling.
@param model_loader: Model loader class to use. Defaults to CudaModelLoader.
@param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible.
"""
self.model_id = model_id
@@ -205,7 +214,13 @@ class SamplingPipeline:
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
)
self.device_manager = device_manager
if isinstance(device, torch.device) or isinstance(device, str):
if torch.device(device).type == "cuda":
self.device_manager = CudaModelManager(device=device)
else:
self.device_manager = DeviceModelManager(device=device)
else:
self.device_manager = device
self.model = self._load_model(
device_manager=self.device_manager, use_fp16=use_fp16
)