rename ModelOnDevice to SwapToDevice

This commit is contained in:
Stephan Auerhahn
2023-08-06 23:46:20 +00:00
parent ced97f0e84
commit 6c18c8443a

View File

@@ -152,7 +152,7 @@ def do_sample(
with autocast(device) as precision_scope:
with model.ema_scope():
num_samples = [num_samples]
with ModelOnDevice(model.conditioner, device):
with SwapToDevice(model.conditioner, device):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
@@ -189,11 +189,11 @@ def do_sample(
model.model, input, sigma, c, **additional_model_inputs
)
with ModelOnDevice(model.denoiser, device):
with ModelOnDevice(model.model, device):
with SwapToDevice(model.denoiser, device):
with SwapToDevice(model.model, device):
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
with ModelOnDevice(model.first_stage_model, device):
with SwapToDevice(model.first_stage_model, device):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -293,7 +293,7 @@ def do_img2img(
with torch.no_grad():
with autocast(device):
with model.ema_scope():
with ModelOnDevice(model.conditioner, device):
with SwapToDevice(model.conditioner, device):
batch, batch_uc = get_batch(
get_unique_embedder_keys_from_conditioner(model.conditioner),
value_dict,
@@ -313,7 +313,7 @@ def do_img2img(
if skip_encode:
z = img
else:
with ModelOnDevice(model.first_stage_model, device):
with SwapToDevice(model.first_stage_model, device):
z = model.encode_first_stage(img)
noise = torch.randn_like(z)
@@ -336,11 +336,11 @@ def do_img2img(
def denoiser(x, sigma, c):
return model.denoiser(model.model, x, sigma, c)
with ModelOnDevice(model.denoiser, device):
with ModelOnDevice(model.model, device):
with SwapToDevice(model.denoiser, device):
with SwapToDevice(model.model, device):
samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
with ModelOnDevice(model.first_stage_model, device):
with SwapToDevice(model.first_stage_model, device):
samples_x = model.decode_first_stage(samples_z)
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -352,7 +352,7 @@ def do_img2img(
return samples
class ModelOnDevice(object):
class SwapToDevice(object):
def __init__(
self,
model: Union[torch.nn.Module, torch.Tensor],