diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 9a90f3b..70b7d06 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -200,12 +200,8 @@ def init_sampling( def get_discretization(params: SamplingParams, key=1) -> SamplingParams: if params.discretization == Discretization.EDM: - params.sigma_min = st.number_input( - f"sigma_min #{key}", value=params.sigma_min - ) # 0.0292 - params.sigma_max = st.number_input( - f"sigma_max #{key}", value=params.sigma_max - ) # 14.6146 + params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min) + params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max) params.rho = st.number_input(f"rho #{key}", value=params.rho) return params diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index a0c9e22..aa9e8cd 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -1,3 +1,4 @@ +import contextlib import os from typing import Union, List, Optional @@ -352,34 +353,35 @@ def do_img2img( return samples -class SwapToDevice(object): - def __init__( - self, - model: Union[torch.nn.Module, torch.Tensor], - device: Union[torch.device, str], - ): - self.model = model - self.device = torch.device(device) - if isinstance(model, torch.Tensor): - self.original_device = model.device +@contextlib.contextmanager +def SwapToDevice( + model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str] +): + """ + Context manager that swaps a model or tensor to a device, and then swaps it back to its original device + when the context is exited. + """ + if isinstance(model, torch.Tensor): + original_device = model.device + else: + param = next(model.parameters(), None) + if param is not None: + original_device = param.device else: - param = next(model.parameters(), None) - if param is not None: - self.original_device = param.device + buf = next(model.buffers(), None) + if buf is not None: + original_device = buf.device else: - buf = next(model.buffers(), None) - if buf is not None: - self.original_device = buf.device - else: - # If device could not be found, turn this into a no-op - self.original_device = self.device + # If device could not be found, do nothing + return + device = torch.device(device) - def __enter__(self): - if self.device != self.original_device: - self.model.to(self.device) + if device != original_device: + model.to(device) - def __exit__(self, *args): - if self.device != self.original_device: - self.model.to(self.original_device) - if torch.cuda.is_available(): - torch.cuda.empty_cache() + yield + + if device != original_device: + model.to(original_device) + if torch.cuda.is_available(): + torch.cuda.empty_cache()