Mirror of https://github.com/Stability-AI/generative-models.git
(synced 2026-02-02 04:14:27 +01:00)

Commit: some PR fixes
@@ -191,11 +191,10 @@ def init_sampling(
     )

-    params = get_discretization(params, key=key)
-
-    params = get_guider(key=key, params=params)
-
-    params = get_sampler(params, key=key)
+    params = get_discretization(params=params, key=key)
+    params = get_guider(params=params, key=key)
+    params = get_sampler(params=params, key=key)
     return params, num_rows, num_cols
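Note on the hunk above: the fix standardizes the three helper calls on keyword arguments. A minimal illustration of why that is safer, using a hypothetical stand-in signature (the real get_discretization / get_guider / get_sampler live elsewhere in the demo code):

from typing import Optional

def get_discretization(params: dict, *, key: Optional[str] = None) -> dict:
    # Hypothetical stand-in: reads state under `key` and extends `params`.
    params.setdefault("discretization", f"default-{key}")
    return params

params: dict = {}
# Passing `params` by keyword keeps the call correct even if the helper's
# positional parameter order changes later.
params = get_discretization(params=params, key="txt2img")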
@@ -165,8 +165,8 @@ class SamplingPipeline:
         model_spec: Optional[SamplingSpec] = None,
         model_path: Optional[Union[str, pathlib.Path]] = None,
         config_path: Optional[Union[str, pathlib.Path]] = None,
-        device: Union[str, torch.Device] = "cuda",
-        swap_device: Optional[Union[str, torch.Device]] = None,
+        device: Union[str, torch.device] = "cuda",
+        swap_device: Optional[Union[str, torch.device]] = None,
         use_fp16: bool = True,
     ) -> None:
         """
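Note on the hunk above: this is more than a style fix. PyTorch's device class is lowercase torch.device; there is no torch.Device, so the old annotations would raise AttributeError as soon as they were evaluated. A quick check:

import torch

# The canonical device type is lowercase torch.device; annotating with a
# capitalized torch.Device fails because no such attribute exists.
dev: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(dev.type)  # "cuda" or "cpu"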
@@ -153,7 +153,7 @@ def do_sample(
         with autocast(device) as precision_scope:
             with model.ema_scope():
                 num_samples = [num_samples]
-                with SwapToDevice(model.conditioner, device):
+                with swap_to_device(model.conditioner, device):
                     batch, batch_uc = get_batch(
                         get_unique_embedder_keys_from_conditioner(model.conditioner),
                         value_dict,
@@ -190,11 +190,11 @@ def do_sample(
                         model.model, input, sigma, c, **additional_model_inputs
                     )

-                with SwapToDevice(model.denoiser, device):
-                    with SwapToDevice(model.model, device):
+                with swap_to_device(model.denoiser, device):
+                    with swap_to_device(model.model, device):
                         samples_z = sampler(denoiser, randn, cond=c, uc=uc)

-                with SwapToDevice(model.first_stage_model, device):
+                with swap_to_device(model.first_stage_model, device):
                     samples_x = model.decode_first_stage(samples_z)
                     samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
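Note on the nested swaps above: model.denoiser and the UNet (model.model) are distinct modules that must both be resident while the sampler runs, so the context managers nest and release innermost-first; each stage occupies the device only while it is needed. A runnable toy of the pattern, with a stand-in context manager (assumed behaviour, not the repo's exact body):

import contextlib

import torch

@contextlib.contextmanager
def to_device_temporarily(module: torch.nn.Module, device: str):
    # Stand-in for the diff's swap_to_device: on `device` inside the block,
    # parked back on CPU afterwards to free memory for the next stage.
    module.to(device)
    try:
        yield module
    finally:
        module.to("cpu")

wrapper = torch.nn.Identity()  # stands in for model.denoiser
unet = torch.nn.Linear(4, 4)   # stands in for model.model
device = "cuda" if torch.cuda.is_available() else "cpu"

with to_device_temporarily(wrapper, device):
    with to_device_temporarily(unet, device):
        # Both modules live on `device` only for this call, mirroring
        # how the sampling step runs above.
        out = wrapper(unet(torch.randn(1, 4, device=device)))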
@@ -294,7 +294,7 @@ def do_img2img(
     with torch.no_grad():
         with autocast(device):
             with model.ema_scope():
-                with SwapToDevice(model.conditioner, device):
+                with swap_to_device(model.conditioner, device):
                     batch, batch_uc = get_batch(
                         get_unique_embedder_keys_from_conditioner(model.conditioner),
                         value_dict,
@@ -314,7 +314,7 @@ def do_img2img(
                 if skip_encode:
                     z = img
                 else:
-                    with SwapToDevice(model.first_stage_model, device):
+                    with swap_to_device(model.first_stage_model, device):
                         z = model.encode_first_stage(img)

                 noise = torch.randn_like(z)
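Note on the hunk above: in the img2img path the input image is first encoded to a latent z (unless skip_encode is set), and only afterwards perturbed with noise before sampling. A small sketch of that noising step, with illustrative shapes and noise level (the actual sigma comes from the sampler's schedule and the chosen strength):

import torch

z = torch.randn(1, 4, 64, 64)  # stands in for model.encode_first_stage(img)
noise = torch.randn_like(z)    # same call as in the code above
sigma = 2.0                    # illustrative; really schedule/strength-driven
noised_z = z + noise * sigma   # noised latent handed to the sampler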
@@ -337,11 +337,11 @@ def do_img2img(
                 def denoiser(x, sigma, c):
                     return model.denoiser(model.model, x, sigma, c)

-                with SwapToDevice(model.denoiser, device):
-                    with SwapToDevice(model.model, device):
+                with swap_to_device(model.denoiser, device):
+                    with swap_to_device(model.model, device):
                         samples_z = sampler(denoiser, noised_z, cond=c, uc=uc)

-                with SwapToDevice(model.first_stage_model, device):
+                with swap_to_device(model.first_stage_model, device):
                     samples_x = model.decode_first_stage(samples_z)
                     samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -354,7 +354,7 @@ def do_img2img(


 @contextlib.contextmanager
-def SwapToDevice(
+def swap_to_device(
     model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
 ):
     """
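Note on the final hunk: the rename brings the context manager in line with PEP 8 snake_case, since despite its CamelCase name it was a generator function, not a class. The body is not shown in this diff; a plausible sketch consistent with the signature above (an assumption, not the repo's actual implementation):

import contextlib
from typing import Union

import torch

@contextlib.contextmanager
def swap_to_device(
    model: Union[torch.nn.Module, torch.Tensor], device: Union[torch.device, str]
):
    # Move `model` to `device` for the duration of the block, then back to CPU.
    if isinstance(model, torch.nn.Module):
        # Modules move in place, so the swap can be undone on exit.
        model.to(device)
        try:
            yield model
        finally:
            model.to("cpu")
    else:
        # Tensor.to returns a new tensor; callers must use the yielded value,
        # since the original tensor is left untouched.
        yield model.to(device)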