abstract device defaults

Stephan Auerhahn
2023-08-12 07:27:25 +00:00
parent 98c4b7753b
commit f6704532a0
2 changed files with 24 additions and 19 deletions

View File

@@ -174,7 +174,7 @@ class SamplingPipeline:
         model_path: Optional[str] = None,
         config_path: Optional[str] = None,
         use_fp16: bool = True,
-        device: Optional[Union[DeviceModelManager, str, torch.device]] = None
+        device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
     ) -> None:
         """
         Sampling pipeline for generating images from a model.
@@ -211,16 +211,13 @@ class SamplingPipeline:
             raise ValueError(
                 f"Checkpoint {self.ckpt} not found, check model spec or config_path"
             )
-        if not isinstance(device, DeviceModelManager):
-            self.device_manager = get_model_manager(device=device)
-        else:
-            self.device_manager = device
+        self.device_manager = get_model_manager(device)
         self.model = self._load_model(
             device_manager=self.device_manager, use_fp16=use_fp16
         )
 
     def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
         config = OmegaConf.load(self.config)
         model = load_model_from_config(config, self.ckpt)
@@ -268,7 +265,7 @@ class SamplingPipeline:
             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
             return_latents=return_latents,
             filter=filter,
-            device_manager=self.device_manager,
+            device=self.device_manager,
         )
 
     def image_to_image(
@@ -308,7 +305,7 @@ class SamplingPipeline:
             force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
             return_latents=return_latents,
             filter=filter,
-            device_manager=self.device_manager,
+            device=self.device_manager,
         )
 
     def wrap_discretization(
@@ -377,7 +374,7 @@ class SamplingPipeline:
             return_latents=return_latents,
             filter=filter,
             add_noise=add_noise,
-            device_manager=self.device_manager,
+            device=self.device_manager,
         )
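
Taken together, the hunks above make device the single entry point: SamplingPipeline.__init__ no longer special-cases DeviceModelManager instances, and the sampling methods forward the resolved manager through the renamed device keyword. A minimal construction sketch, using only the parameters visible in the context lines; the checkpoint path is a hypothetical placeholder:

    import torch

    # Any of these forms is accepted by the new `device` parameter;
    # omitting it auto-selects CUDA when available, else CPU.
    pipe = SamplingPipeline(model_path="checkpoints/model.safetensors")
    pipe = SamplingPipeline(model_path="checkpoints/model.safetensors", device="cuda")
    pipe = SamplingPipeline(model_path="checkpoints/model.safetensors", device=torch.device("cpu"))
    pipe = SamplingPipeline(model_path="checkpoints/model.safetensors", device=CudaModelManager(device="cuda"))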

View File

@@ -109,7 +109,7 @@ class DeviceModelManager(object):
 class CudaModelManager(DeviceModelManager):
     """
     Device manager that loads a model to a CUDA device, optionally swapping to CPU when not in use.
-    """
+    """
 
     @contextlib.contextmanager
     def use(self, model: Union[torch.nn.Module, torch.Tensor]):
@@ -141,14 +141,19 @@ def perform_save_locally(save_path, samples):
         base_count += 1
 
-def get_model_manager(device: Union[str,torch.device]) -> DeviceModelManager:
-    if isinstance(device, torch.device) or isinstance(device, str):
-        if torch.device(device).type == "cuda":
-            return CudaModelManager(device=device)
-        else:
-            return DeviceModelManager(device=device)
-    else:
+def get_model_manager(
+    device: Optional[Union[DeviceModelManager, str, torch.device]]
+) -> DeviceModelManager:
+    if isinstance(device, DeviceModelManager):
         return device
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device)
+    if device.type == "cuda":
+        return CudaModelManager(device=device)
+    else:
+        return DeviceModelManager(device=device)
+
 
 class Img2ImgDiscretizationWrapper:
     """
@@ -217,13 +222,15 @@ def do_sample(
     batch2model_input: Optional[List] = None,
     return_latents=False,
     filter=None,
-    device_manager: DeviceModelManager = DeviceModelManager("cuda"),
+    device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
 ):
     if force_uc_zero_embeddings is None:
         force_uc_zero_embeddings = []
     if batch2model_input is None:
         batch2model_input = []
+    device_manager = get_model_manager(device=device)
+
     with torch.no_grad():
         with device_manager.autocast():
             with model.ema_scope():
@@ -367,8 +374,9 @@ def do_img2img(
     skip_encode=False,
     filter=None,
     add_noise=True,
-    device_manager=DeviceModelManager("cuda"),
+    device: Optional[Union[DeviceModelManager, str, torch.device]] = None,
 ):
+    device_manager = get_model_manager(device)
     with torch.no_grad():
         with device_manager.autocast():
             with model.ema_scope():
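
do_sample and do_img2img now perform the same normalization internally, so the hard-coded DeviceModelManager("cuda") defaults are gone and callers can pass a plain string, a torch.device, a prebuilt manager, or nothing at all. A hedged call sketch; the leading positional arguments (img, model, sampler, value_dict, num_samples) live outside this diff and are assumed here:

    samples = do_img2img(
        img,
        model,
        sampler,
        value_dict,
        num_samples,
        device="cuda",  # was: device_manager=DeviceModelManager("cuda")
    )
    # Omitting device entirely now auto-detects CUDA, falling back to CPU.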