mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-03 21:04:28 +01:00
more fixes and cleanup
This commit is contained in:
@@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec,
|
||||
use_fp16=True,
|
||||
model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
|
||||
device_manager=CudaModelManager(device="cuda", swap_device="cpu"),
|
||||
)
|
||||
else:
|
||||
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
|
||||
@@ -207,7 +207,7 @@ def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
|
||||
|
||||
|
||||
def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
|
||||
if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM:
|
||||
if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
|
||||
params.s_churn = st.sidebar.number_input(
|
||||
f"s_churn #{key}", value=params.s_churn, min_value=0.0
|
||||
)
|
||||
@@ -221,10 +221,7 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
|
||||
f"s_noise #{key}", value=params.s_noise, min_value=0.0
|
||||
)
|
||||
|
||||
elif (
|
||||
params.sampler == Sampler.EULER_ANCESTRAL
|
||||
or params.sampler == Sampler.DPMPP2S_ANCESTRAL
|
||||
):
|
||||
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
|
||||
params.s_noise = st.sidebar.number_input(
|
||||
"s_noise", value=params.s_noise, min_value=0.0
|
||||
)
|
||||
|
||||
@@ -205,9 +205,9 @@ class SamplingPipeline:
|
||||
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
|
||||
)
|
||||
|
||||
self.model_manager = device_manager
|
||||
self.device_manager = device_manager
|
||||
self.model = self._load_model(
|
||||
device_manager=self.model_manager, use_fp16=use_fp16
|
||||
device_manager=self.device_manager, use_fp16=use_fp16
|
||||
)
|
||||
|
||||
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
|
||||
@@ -229,7 +229,7 @@ class SamplingPipeline:
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
noise_strength: Optional[float] = None,
|
||||
filter: Any = None,
|
||||
filter=None,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
@@ -257,7 +257,7 @@ class SamplingPipeline:
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
model_manager=self.model_manager,
|
||||
device_manager=self.device_manager,
|
||||
)
|
||||
|
||||
def image_to_image(
|
||||
@@ -269,7 +269,7 @@ class SamplingPipeline:
|
||||
samples: int = 1,
|
||||
return_latents: bool = False,
|
||||
noise_strength: Optional[float] = None,
|
||||
filter: Any = None,
|
||||
filter=None,
|
||||
):
|
||||
sampler = get_sampler_config(params)
|
||||
|
||||
@@ -295,7 +295,7 @@ class SamplingPipeline:
|
||||
force_uc_zero_embeddings=["txt"] if not self.specs.is_legacy else [],
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
device=self.device,
|
||||
device_manager=self.device_manager,
|
||||
)
|
||||
|
||||
def wrap_discretization(
|
||||
@@ -364,7 +364,7 @@ class SamplingPipeline:
|
||||
return_latents=return_latents,
|
||||
filter=filter,
|
||||
add_noise=add_noise,
|
||||
device=self.device,
|
||||
device_manager=self.device_manager,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user