mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-20 05:04:38 +01:00
more fixes and cleanup
This commit is contained in:
@@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
|
||||
pipeline = SamplingPipeline(
|
||||
model_spec=spec,
|
||||
use_fp16=True,
|
||||
model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
|
||||
device_manager=CudaModelManager(device="cuda", swap_device="cpu"),
|
||||
)
|
||||
else:
|
||||
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
|
||||
@@ -207,7 +207,7 @@ def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
|
||||
|
||||
|
||||
def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
|
||||
if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM:
|
||||
if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
|
||||
params.s_churn = st.sidebar.number_input(
|
||||
f"s_churn #{key}", value=params.s_churn, min_value=0.0
|
||||
)
|
||||
@@ -221,10 +221,7 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
|
||||
f"s_noise #{key}", value=params.s_noise, min_value=0.0
|
||||
)
|
||||
|
||||
elif (
|
||||
params.sampler == Sampler.EULER_ANCESTRAL
|
||||
or params.sampler == Sampler.DPMPP2S_ANCESTRAL
|
||||
):
|
||||
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
|
||||
params.s_noise = st.sidebar.number_input(
|
||||
"s_noise", value=params.s_noise, min_value=0.0
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user