diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py
index b576669..a0c9e22 100644
--- a/sgm/inference/helpers.py
+++ b/sgm/inference/helpers.py
@@ -152,7 +152,7 @@ def do_sample(
         with autocast(device) as precision_scope:
             with model.ema_scope():
                 num_samples = [num_samples]
-                with ModelOnDevice(model.conditioner, device):
+                with SwapToDevice(model.conditioner, device):
                     batch, batch_uc = get_batch(
                         get_unique_embedder_keys_from_conditioner(model.conditioner),
                         value_dict,
@@ -189,11 +189,11 @@ def do_sample(
                         model.model, input, sigma, c, **additional_model_inputs
                     )
 
-                with ModelOnDevice(model.denoiser, device):
-                    with ModelOnDevice(model.model, device):
+                with SwapToDevice(model.denoiser, device):
+                    with SwapToDevice(model.model, device):
                         samples_z = sampler(denoiser, randn, cond=c, uc=uc)
 
-                with ModelOnDevice(model.first_stage_model, device):
+                with SwapToDevice(model.first_stage_model, device):
                     samples_x = model.decode_first_stage(samples_z)
 
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -293,7 +293,7 @@ def do_img2img(
     with torch.no_grad():
         with autocast(device):
             with model.ema_scope():
-                with ModelOnDevice(model.conditioner, device):
+                with SwapToDevice(model.conditioner, device):
                     batch, batch_uc = get_batch(
                         get_unique_embedder_keys_from_conditioner(model.conditioner),
                         value_dict,
@@ -313,7 +313,7 @@ def do_img2img(
                 if skip_encode:
                     z = img
                 else:
-                    with ModelOnDevice(model.first_stage_model, device):
+                    with SwapToDevice(model.first_stage_model, device):
                         z = model.encode_first_stage(img)
 
                 noise = torch.randn_like(z)
@@ -336,11 +336,11 @@ def do_img2img(
                 def denoiser(x, sigma, c):
                     return model.denoiser(model.model, x, sigma, c)
 
-                with ModelOnDevice(model.denoiser, device):
-                    with ModelOnDevice(model.model, device):
+                with SwapToDevice(model.denoiser, device):
+                    with SwapToDevice(model.model, device):
                         samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)
 
-                with ModelOnDevice(model.first_stage_model, device):
+                with SwapToDevice(model.first_stage_model, device):
                     samples_x = model.decode_first_stage(samples_z)
 
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -352,7 +352,7 @@ def do_img2img(
                 return samples
 
 
-class ModelOnDevice(object):
+class SwapToDevice(object):
     def __init__(
         self,
         model: Union[torch.nn.Module, torch.Tensor],
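
Note: this diff is a pure rename of `ModelOnDevice` to `SwapToDevice`; the class body is not shown beyond the first lines of `__init__`. For readers unfamiliar with the helper, below is a minimal sketch of what such a context manager presumably does, assuming it moves the wrapped module (or tensor) to the target device on entry and restores the original placement on exit. The attribute names and the `empty_cache` call are illustrative, not taken from this diff.

```python
from typing import Union

import torch


class SwapToDevice(object):
    def __init__(
        self,
        model: Union[torch.nn.Module, torch.Tensor],
        device: Union[torch.device, str],
    ):
        self.model = model
        self.device = torch.device(device)
        # Remember where the weights currently live so they can be restored.
        if isinstance(model, torch.nn.Module):
            self.original_device = next(model.parameters()).device
        else:
            self.original_device = model.device

    def __enter__(self):
        # Move to the target device for the duration of the block.
        self.model = self.model.to(self.device)
        return self.model

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore the original placement and release cached GPU memory.
        self.model = self.model.to(self.original_device)
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
```

Under that reading, each `with SwapToDevice(...)` block in `do_sample`/`do_img2img` keeps only the submodel currently in use (conditioner, denoiser/UNet, or first-stage VAE) on the accelerator, which is what allows pipelines larger than VRAM to run one stage at a time.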