diff --git a/sgm/modules/diffusionmodules/discretizer.py b/sgm/modules/diffusionmodules/discretizer.py index 33d716c..0e4e1a7 100644 --- a/sgm/modules/diffusionmodules/discretizer.py +++ b/sgm/modules/diffusionmodules/discretizer.py @@ -1,6 +1,7 @@ import torch import numpy as np from functools import partial +from abc import abstractmethod from ...util import append_zero from ...modules.diffusionmodules.util import make_beta_schedule @@ -13,11 +14,15 @@ def generate_roughly_equally_spaced_steps( class Discretization: - def __call__(self, n, do_append_zero=True, device="cuda", flip=False): - sigmas = self.get_sigmas(n, device) + def __call__(self, n, do_append_zero=True, device="cpu", flip=False): + sigmas = self.get_sigmas(n, device=device) sigmas = append_zero(sigmas) if do_append_zero else sigmas return sigmas if not flip else torch.flip(sigmas, (0,)) + @abstractmethod + def get_sigmas(self, n, device): + raise NotImplementedError("abstract class should not be called") + class EDMDiscretization(Discretization): def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0): @@ -25,7 +30,7 @@ class EDMDiscretization(Discretization): self.sigma_max = sigma_max self.rho = rho - def get_sigmas(self, n, device): + def get_sigmas(self, n, device="cpu"): ramp = torch.linspace(0, 1, n, device=device) min_inv_rho = self.sigma_min ** (1 / self.rho) max_inv_rho = self.sigma_max ** (1 / self.rho) @@ -40,6 +45,7 @@ class LegacyDDPMDiscretization(Discretization): linear_end=0.0120, num_timesteps=1000, ): + super().__init__() self.num_timesteps = num_timesteps betas = make_beta_schedule( "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end @@ -48,7 +54,7 @@ class LegacyDDPMDiscretization(Discretization): self.alphas_cumprod = np.cumprod(alphas, axis=0) self.to_torch = partial(torch.tensor, dtype=torch.float32) - def get_sigmas(self, n, device): + def get_sigmas(self, n, device="cpu"): if n < self.num_timesteps: timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps) alphas_cumprod = self.alphas_cumprod[timesteps]