diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py index f7bedc5..33d716c 100644 --- a/sgm/modules/diffusionmodules/discretizer.py +++ b/sgm/modules/diffusionmodules/discretizer.py @@ -6,24 +6,10 @@ from ...util import append_zero from ...modules.diffusionmodules.util import make_beta_schedule -def generate_roughly_equally_spaced_steps(n, m): - # 0, ..., m - 1 - m -= 1 - - # We are getting rid of leading 0 later, so increase steps - n += 1 - - # Calculate the step size - step = m / (n - 1) - - # Generate the list - steps_reversed = [int(m - i * step) for i in range(n)] - steps = steps_reversed[::-1] - - # Get rid of leading 0 - steps = steps[1:] - - return np.array(steps) +def generate_roughly_equally_spaced_steps( + num_substeps: int, max_step: int +) -> np.ndarray: + return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1] class Discretization: