diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index 2d6972e..b6814ec 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -191,11 +191,10 @@ def init_sampling( ) ) - params = get_discretization(params, key=key) + params = get_discretization(params=params, key=key) + params = get_guider(params=params, key=key) + params = get_sampler(params=params, key=key) - params = get_guider(key=key, params=params) - - params = get_sampler(params, key=key) return params, num_rows, num_cols diff --git a/sgm/inference/api.py b/sgm/inference/api.py index ecdf706..082ca18 100644 --- a/sgm/inference/api.py +++ b/sgm/inference/api.py @@ -165,8 +165,8 @@ class SamplingPipeline: model_spec: Optional[SamplingSpec] = None, model_path: Optional[Union[str, pathlib.Path]] = None, config_path: Optional[Union[str, pathlib.Path]] = None, - device: Union[str, torch.Device] = "cuda", - swap_device: Optional[Union[str, torch.Device]] = None, + device: Union[str, torch.device] = "cuda", + swap_device: Optional[Union[str, torch.device]] = None, use_fp16: bool = True, ) -> None: """ diff --git a/sgm/inference/helpers.py b/sgm/inference/helpers.py index aa9e8cd..68409a2 100644 --- a/sgm/inference/helpers.py +++ b/sgm/inference/helpers.py @@ -153,7 +153,7 @@ def do_sample( with autocast(device) as precision_scope: with model.ema_scope(): num_samples = [num_samples] - with SwapToDevice(model.conditioner, device): + with swap_to_device(model.conditioner, device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -190,11 +190,11 @@ def do_sample( model.model, input, sigma, c, **additional_model_inputs ) - with SwapToDevice(model.denoiser, device): - with SwapToDevice(model.model, device): + with swap_to_device(model.denoiser, device): + with swap_to_device(model.model, device): samples_z = sampler(denoiser, randn, cond=c, uc=uc) - with SwapToDevice(model.first_stage_model, device): + with swap_to_device(model.first_stage_model, device): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -294,7 +294,7 @@ def do_img2img( with torch.no_grad(): with autocast(device): with model.ema_scope(): - with SwapToDevice(model.conditioner, device): + with swap_to_device(model.conditioner, device): batch, batch_uc = get_batch( get_unique_embedder_keys_from_conditioner(model.conditioner), value_dict, @@ -314,7 +314,7 @@ def do_img2img( if skip_encode: z = img else: - with SwapToDevice(model.first_stage_model, device): + with swap_to_device(model.first_stage_model, device): z = model.encode_first_stage(img) noise = torch.randn_like(z) @@ -337,11 +337,11 @@ def do_img2img( def denoiser(x, sigma, c): return model.denoiser(model.model, x, sigma, c) - with SwapToDevice(model.denoiser, device): - with SwapToDevice(model.model, device): + with swap_to_device(model.denoiser, device): + with swap_to_device(model.model, device): samples_z = sampler(denoiser, noised_z, cond=c, uc=uc) - with SwapToDevice(model.first_stage_model, device): + with swap_to_device(model.first_stage_model, device): samples_x = model.decode_first_stage(samples_z) samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0) @@ -354,7 +354,7 @@ def do_img2img( @contextlib.contextmanager -def SwapToDevice( +def swap_to_device( model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str] ): """