more fixes and cleanup

This commit is contained in:
Stephan Auerhahn
2023-08-10 05:11:34 -07:00
parent 9b18e6fa19
commit de7a627978
2 changed files with 10 additions and 13 deletions

View File

@@ -35,7 +35,7 @@ def init_st(spec: SamplingSpec, load_ckpt=True, load_filter=True) -> Dict[str, A
pipeline = SamplingPipeline(
model_spec=spec,
use_fp16=True,
model_loader=CudaModelManager(device="cuda", swap_device="cpu"),
device_manager=CudaModelManager(device="cuda", swap_device="cpu"),
)
else:
pipeline = SamplingPipeline(model_spec=spec, use_fp16=False)
@@ -207,7 +207,7 @@ def get_discretization(params: SamplingParams, key=1) -> SamplingParams:
def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
if params.sampler == Sampler.EULER_EDM or params.sampler == Sampler.HEUN_EDM:
if params.sampler in (Sampler.EULER_EDM, Sampler.HEUN_EDM):
params.s_churn = st.sidebar.number_input(
f"s_churn #{key}", value=params.s_churn, min_value=0.0
)
@@ -221,10 +221,7 @@ def get_sampler(params: SamplingParams, key=1) -> SamplingParams:
f"s_noise #{key}", value=params.s_noise, min_value=0.0
)
elif (
params.sampler == Sampler.EULER_ANCESTRAL
or params.sampler == Sampler.DPMPP2S_ANCESTRAL
):
elif params.sampler in (Sampler.EULER_ANCESTRAL, Sampler.DPMPP2S_ANCESTRAL):
params.s_noise = st.sidebar.number_input(
"s_noise", value=params.s_noise, min_value=0.0
)