Mirror of https://github.com/Stability-AI/generative-models.git
simplify device_manager usage
@@ -36,7 +36,7 @@ def init_st(
         pipeline = SamplingPipeline(
             model_spec=spec,
             use_fp16=True,
-            device_manager=CudaModelManager(device="cuda", swap_device="cpu"),
+            device=CudaModelManager(device="cuda", swap_device="cpu"),
         )
     else:
         pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
@@ -57,6 +57,13 @@ class Thresholder(str, Enum):
 
 @dataclass
 class SamplingParams:
+    """
+    Parameters for sampling.
+    The defaults here are derived from user preference testing.
+    They will be subject to change in the future, likely pulled
+    from model specs instead of global defaults.
+    """
+
     width: int = 1024
     height: int = 1024
     steps: int = 40
@@ -167,7 +174,9 @@ class SamplingPipeline:
         model_path: Optional[str] = None,
         config_path: Optional[str] = None,
         use_fp16: bool = True,
-        device_manager: DeviceModelManager = CudaModelManager(device="cuda"),
+        device: Union[DeviceModelManager, str, torch.device] = CudaModelManager(
+            device="cuda"
+        ),
     ) -> None:
         """
         Sampling pipeline for generating images from a model.
@@ -177,7 +186,7 @@ class SamplingPipeline:
         @param model_path: Path to model checkpoints folder.
         @param config_path: Path to model config folder.
         @param use_fp16: Whether to use fp16 for sampling.
-        @param model_loader: Model loader class to use. Defaults to CudaModelLoader.
+        @param device: Device manager to use with this pipeline. If a string or torch.device is passed, a device manager will be created based on device type if possible.
         """
 
         self.model_id = model_id
@@ -205,7 +214,13 @@ class SamplingPipeline:
             f"Checkpoint {self.ckpt} not found, check model spec or config_path"
         )
 
-        self.device_manager = device_manager
+        if isinstance(device, torch.device) or isinstance(device, str):
+            if torch.device(device).type == "cuda":
+                self.device_manager = CudaModelManager(device=device)
+            else:
+                self.device_manager = DeviceModelManager(device=device)
+        else:
+            self.device_manager = device
         self.model = self._load_model(
             device_manager=self.device_manager, use_fp16=use_fp16
         )
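For context, a minimal sketch of how the constructor could be called after this change. The import paths and the model_id value below are assumptions made for illustration; only the class names and the device/use_fp16 parameters appear in the diff above.

import torch

# Assumed import locations; the diff does not show the module paths.
from sgm.inference.api import ModelArchitecture, SamplingPipeline
from sgm.inference.helpers import CudaModelManager

# Passing a device manager explicitly, as in the first hunk:
pipeline = SamplingPipeline(
    model_id=ModelArchitecture.SDXL_V1_BASE,  # assumed model id
    use_fp16=True,
    device=CudaModelManager(device="cuda", swap_device="cpu"),
)

# Passing a plain string or torch.device: per the new branch in __init__,
# "cuda" devices get a CudaModelManager, everything else a DeviceModelManager.
cpu_pipeline = SamplingPipeline(
    model_id=ModelArchitecture.SDXL_V1_BASE,  # assumed model id
    use_fp16=False,
    device=torch.device("cpu"),
)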