mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-04 13:24:28 +01:00
abstract device defaults
This commit is contained in:
@@ -174,7 +174,7 @@ class SamplingPipeline:
|
||||
model_path: Optional[str] = None,
|
||||
config_path: Optional[str] = None,
|
||||
use_fp16: bool = True,
|
||||
device: Optional[Union[DeviceModelManager, str, torch.device]] = None
|
||||
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sampling pipeline for generating images from a model.
|
||||
@@ -211,16 +211,13 @@ class SamplingPipeline:
|
||||
raise ValueError(
|
||||
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
|
||||
)
|
||||
if not isinstance(device, DeviceModelManager):
|
||||
self.device_manager = get_model_manager(device=device)
|
||||
else:
|
||||
self.device_manager = device
|
||||
|
||||
self.device_manager = get_model_manager(device)
|
||||
|
||||
self.model = self._load_model(
|
||||
device_manager=self.device_manager, use_fp16=use_fp16
|
||||
)
|
||||
|
||||
|
||||
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
|
||||
config = OmegaConf.load(self.config)
|
||||
model = load_model_from_config(config, self.ckpt)
|
||||
@@ -268,7 +265,7 @@ class SamplingPipeline:
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
device_manager=self.device_manager,
|
||||
device=self.device_manager,
|
||||
)
|
||||
|
||||
def image_to_image(
|
||||
@@ -308,7 +305,7 @@ class SamplingPipeline:
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
device_manager=self.device_manager,
|
||||
device=self.device_manager,
|
||||
)
|
||||
|
||||
def wrap_discretization(
|
||||
@@ -377,7 +374,7 @@ class SamplingPipeline:
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
add_noise=add_noise,
|
||||
device_manager=self.device_manager,
|
||||
device=self.device_manager,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ class DeviceModelManager(object):
|
||||
class CudaModelManager(DeviceModelManager):
|
||||
"""
|
||||
Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
|
||||
"""
|
||||
"""
|
||||
|
||||
@contextlib.contextmanager
|
||||
def use(self, model: Union[torch.nn.Module, torch.Tensor]):
|
||||
@@ -141,14 +141,19 @@ def perform_save_locally(save_path, samples):
|
||||
base_count += 1
|
||||
|
||||
|
||||
def get_model_manager(device: Union[str,torch.device]) -> DeviceModelManager:
|
||||
if isinstance(device, torch.device) or isinstance(device, str):
|
||||
if torch.device(device).type == "cuda":
|
||||
return CudaModelManager(device=device)
|
||||
else:
|
||||
return DeviceModelManager(device=device)
|
||||
else:
|
||||
def get_model_manager(
|
||||
device: Optional[Union[DeviceModelManager, str, torch.device]]
|
||||
) -> DeviceModelManager:
|
||||
if isinstance(device, DeviceModelManager):
|
||||
return device
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
device = torch.device(device)
|
||||
if device.type == "cuda":
|
||||
return CudaModelManager(device=device)
|
||||
else:
|
||||
return DeviceModelManager(device=device)
|
||||
|
||||
|
||||
class Img2ImgDiscretizationWrapper:
|
||||
"""
|
||||
@@ -217,13 +222,15 @@ def do_sample(
|
||||
batch2model_input: Optional[List] = None,
|
||||
return_latents=False,
|
||||
filter=None,
|
||||
device_manager: DeviceModelManager = DeviceModelManager("cuda"),
|
||||
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
|
||||
):
|
||||
if force_uc_zero_embeddings is None:
|
||||
force_uc_zero_embeddings = []
|
||||
if batch2model_input is None:
|
||||
batch2model_input = []
|
||||
|
||||
device_manager = get_model_manager(device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
with device_manager.autocast():
|
||||
with model.ema_scope():
|
||||
@@ -367,8 +374,9 @@ def do_img2img(
|
||||
skip_encode=False,
|
||||
filter=None,
|
||||
add_noise=True,
|
||||
device_manager=DeviceModelManager("cuda"),
|
||||
device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
|
||||
):
|
||||
device_manager = get_model_manager(device)
|
||||
with torch.no_grad():
|
||||
with device_manager.autocast():
|
||||
with model.ema_scope():
|
||||
|
||||
Reference in New Issue
Block a user