context manager

This commit is contained in:
Stephan Auerhahn
2023-08-09 12:38:44 -07:00
parent a726ce3eb7
commit f86ffac274
2 changed files with 31 additions and 33 deletions

View File

@@ -200,12 +200,8 @@ def init_sampling(
def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
if params.discretization == Discretization.EDM:
params.sigma_min = st.number_input(
f"sigma_min #{key}", value=params.sigma_min
) # 0.0292
params.sigma_max = st.number_input(
f"sigma_max #{key}", value=params.sigma_max
) # 14.6146
params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
params.rho = st.number_input(f"rho #{key}", value=params.rho)
return params

View File

@@ -1,3 +1,4 @@
import contextlib
import os
from typing import Union, List, Optional
@@ -352,34 +353,35 @@ def do_img2img(
return samples
class SwapToDevice(object):
def __init__(
self,
model: Union[torch.nn.Module, torch.Tensor],
device: Union[torch.device, str],
):
self.model = model
self.device = torch.device(device)
if isinstance(model, torch.Tensor):
self.original_device = model.device
@contextlib.contextmanager
def SwapToDevice(
model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
):
"""
Context manager that swaps a model or tensor to a device, and then swaps it back to its original device
when the context is exited.
"""
if isinstance(model, torch.Tensor):
original_device = model.device
else:
param = next(model.parameters(), None)
if param is not None:
original_device = param.device
else:
param = next(model.parameters(), None)
if param is not None:
self.original_device = param.device
buf = next(model.buffers(), None)
if buf is not None:
original_device = buf.device
else:
buf = next(model.buffers(), None)
if buf is not None:
self.original_device = buf.device
else:
# If device could not be found, turn this into a no-op
self.original_device = self.device
# If device could not be found, do nothing
return
device = torch.device(device)
def __enter__(self):
if self.device != self.original_device:
self.model.to(self.device)
if device != original_device:
model.to(device)
def __exit__(self, *args):
if self.device != self.original_device:
self.model.to(self.original_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()
yield
if device != original_device:
model.to(original_device)
if torch.cuda.is_available():
torch.cuda.empty_cache()