mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
context manager
This commit is contained in:
@@ -200,12 +200,8 @@ def init_sampling(
|
||||
|
||||
def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
|
||||
if params.discretization == Discretization.EDM:
|
||||
params.sigma_min = st.number_input(
|
||||
f"sigma_min #{key}", value=params.sigma_min
|
||||
) # 0.0292
|
||||
params.sigma_max = st.number_input(
|
||||
f"sigma_max #{key}", value=params.sigma_max
|
||||
) # 14.6146
|
||||
params.sigma_min = st.number_input(f"sigma_min #{key}", value=params.sigma_min)
|
||||
params.sigma_max = st.number_input(f"sigma_max #{key}", value=params.sigma_max)
|
||||
params.rho = st.number_input(f"rho #{key}", value=params.rho)
|
||||
return params
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import os
|
||||
from typing import Union, List, Optional
|
||||
|
||||
@@ -352,34 +353,35 @@ def do_img2img(
|
||||
return samples
|
||||
|
||||
|
||||
class SwapToDevice(object):
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[torch.nn.Module, torch.Tensor],
|
||||
device: Union[torch.device, str],
|
||||
):
|
||||
self.model = model
|
||||
self.device = torch.device(device)
|
||||
if isinstance(model, torch.Tensor):
|
||||
self.original_device = model.device
|
||||
@contextlib.contextmanager
|
||||
def SwapToDevice(
|
||||
model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
|
||||
):
|
||||
"""
|
||||
Context manager that swaps a model or tensor to a device, and then swaps it back to its original device
|
||||
when the context is exited.
|
||||
"""
|
||||
if isinstance(model, torch.Tensor):
|
||||
original_device = model.device
|
||||
else:
|
||||
param = next(model.parameters(), None)
|
||||
if param is not None:
|
||||
original_device = param.device
|
||||
else:
|
||||
param = next(model.parameters(), None)
|
||||
if param is not None:
|
||||
self.original_device = param.device
|
||||
buf = next(model.buffers(), None)
|
||||
if buf is not None:
|
||||
original_device = buf.device
|
||||
else:
|
||||
buf = next(model.buffers(), None)
|
||||
if buf is not None:
|
||||
self.original_device = buf.device
|
||||
else:
|
||||
# If device could not be found, turn this into a no-op
|
||||
self.original_device = self.device
|
||||
# If device could not be found, do nothing
|
||||
return
|
||||
device = torch.device(device)
|
||||
|
||||
def __enter__(self):
|
||||
if self.device != self.original_device:
|
||||
self.model.to(self.device)
|
||||
if device != original_device:
|
||||
model.to(device)
|
||||
|
||||
def __exit__(self, *args):
|
||||
if self.device != self.original_device:
|
||||
self.model.to(self.original_device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
yield
|
||||
|
||||
if device != original_device:
|
||||
model.to(original_device)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
Reference in New Issue
Block a user