diff --git a/scripts/demo/streamlit_helpers.py b/scripts/demo/streamlit_helpers.py index ddc9c6b..dfb0056 100644 --- a/scripts/demo/streamlit_helpers.py +++ b/scripts/demo/streamlit_helpers.py @@ -325,10 +325,8 @@ def init_sampling( def get_discretization(discretization, key=1): if discretization == "LegacyDDPMDiscretization": - use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False) discretization_config = { "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization", - "params": {"legacy_range": not use_new_range}, } elif discretization == "EDMDiscretization": sigma_min = st.number_input(f"sigma_min #{key}", value=0.03) # 0.0292 diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py index f632186..f7bedc5 100644 --- a/sgm/modules/diffusionmodules/discretizer.py +++ b/sgm/modules/diffusionmodules/discretizer.py @@ -6,6 +6,26 @@ from ...util import append_zero from ...modules.diffusionmodules.util import make_beta_schedule +def generate_roughly_equally_spaced_steps(n, m): + # 0, ..., m - 1 + m -= 1 + + # We are getting rid of leading 0 later, so increase steps + n += 1 + + # Calculate the step size + step = m / (n - 1) + + # Generate the list + steps_reversed = [int(m - i * step) for i in range(n)] + steps = steps_reversed[::-1] + + # Get rid of leading 0 + steps = steps[1:] + + return np.array(steps) + + class Discretization: def __call__(self, n, do_append_zero=True, device="cuda", flip=False): sigmas = self.get_sigmas(n, device) @@ -33,7 +53,6 @@ class LegacyDDPMDiscretization(Discretization): linear_start=0.00085, linear_end=0.0120, num_timesteps=1000, - legacy_range=True, ): self.num_timesteps = num_timesteps betas = make_beta_schedule( @@ -42,23 +61,15 @@ class LegacyDDPMDiscretization(Discretization): alphas = 1.0 - betas self.alphas_cumprod = np.cumprod(alphas, axis=0) self.to_torch = partial(torch.tensor, dtype=torch.float32) - self.legacy_range = legacy_range def get_sigmas(self, n, device): if n < self.num_timesteps: - c = self.num_timesteps // n - - if self.legacy_range: - timesteps = np.asarray(list(range(0, self.num_timesteps, c))) - timesteps += 1 # Legacy LDM Hack - else: - timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c))) - timesteps -= 1 - timesteps = timesteps[1:] - + timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) alphas_cumprod = self.alphas_cumprod[timesteps] - else: + elif n == self.num_timesteps: alphas_cumprod = self.alphas_cumprod + else: + raise ValueError to_torch = partial(torch.tensor, dtype=torch.float32, device=device) sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5