Mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-20 06:44:22 +01:00)
Changed LegacyDDPMDiscretization for sampling
@@ -325,10 +325,8 @@ def init_sampling(

 def get_discretization(discretization, key=1):
     if discretization == "LegacyDDPMDiscretization":
-        use_new_range = st.checkbox(f"Start from highest noise level? #{key}", False)
         discretization_config = {
             "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
-            "params": {"legacy_range": not use_new_range},
         }
     elif discretization == "EDMDiscretization":
         sigma_min = st.number_input(f"sigma_min #{key}", value=0.03)  # 0.0292
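
The demo now emits the LegacyDDPMDiscretization config without a "params" entry, so the discretizer is built entirely from its defaults. A minimal sketch of how such a target/params config resolves to an object; the instantiate helper below is illustrative, not the repo's own config utility, and only the target string comes from the diff:

import importlib

def instantiate(config):
    # Resolve "pkg.module.ClassName" and call it with the (optional) params dict.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))

discretization_config = {
    "target": "sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization",
}
discretization = instantiate(discretization_config)  # defaults apply, e.g. num_timesteps=1000
sigmas = discretization(50, device="cpu")  # 50 sampling sigmas (plus an appended zero, per do_append_zero)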
@@ -6,6 +6,26 @@ from ...util import append_zero
 from ...modules.diffusionmodules.util import make_beta_schedule


+def generate_roughly_equally_spaced_steps(n, m):
+    # 0, ..., m - 1
+    m -= 1
+
+    # We are getting rid of leading 0 later, so increase steps
+    n += 1
+
+    # Calculate the step size
+    step = m / (n - 1)
+
+    # Generate the list
+    steps_reversed = [int(m - i * step) for i in range(n)]
+    steps = steps_reversed[::-1]
+
+    # Get rid of leading 0
+    steps = steps[1:]
+
+    return np.array(steps)
+
+
 class Discretization:
     def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
         sigmas = self.get_sigmas(n, device)
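
A quick sanity check of the new helper; the import path is inferred from the target string used in the demo config above:

from sgm.modules.diffusionmodules.discretizer import generate_roughly_equally_spaced_steps

# n=4 sub-steps out of m=1000 training steps:
# step = 999 / 4 = 249.75, reversed list [999, 749, 499, 249, 0],
# flipped and with the leading 0 dropped.
print(generate_roughly_equally_spaced_steps(4, 1000))  # -> [249 499 749 999]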
@@ -33,7 +53,6 @@ class LegacyDDPMDiscretization(Discretization):
         linear_start=0.00085,
         linear_end=0.0120,
         num_timesteps=1000,
-        legacy_range=True,
     ):
         self.num_timesteps = num_timesteps
         betas = make_beta_schedule(
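
Note that the constructor no longer accepts legacy_range, so any lingering config that still passes it will fail at instantiation. A small hedged illustration:

from sgm.modules.diffusionmodules.discretizer import LegacyDDPMDiscretization

disc = LegacyDDPMDiscretization(num_timesteps=1000)  # fine, uses the default beta schedule
# LegacyDDPMDiscretization(legacy_range=True)        # would now raise TypeError (unexpected keyword)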
@@ -42,23 +61,15 @@ class LegacyDDPMDiscretization(Discretization):
         alphas = 1.0 - betas
         self.alphas_cumprod = np.cumprod(alphas, axis=0)
         self.to_torch = partial(torch.tensor, dtype=torch.float32)
-        self.legacy_range = legacy_range

     def get_sigmas(self, n, device):
         if n < self.num_timesteps:
-            c = self.num_timesteps // n
-
-            if self.legacy_range:
-                timesteps = np.asarray(list(range(0, self.num_timesteps, c)))
-                timesteps += 1  # Legacy LDM Hack
-            else:
-                timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c)))
-                timesteps -= 1
-                timesteps = timesteps[1:]
-
+            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
             alphas_cumprod = self.alphas_cumprod[timesteps]
-        else:
+        elif n == self.num_timesteps:
             alphas_cumprod = self.alphas_cumprod
+        else:
+            raise ValueError

         to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
         sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
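
The practical effect of the new indexing: the old legacy_range path for n=4 selected timesteps [1, 251, 501, 751] and never reached the final training step, while generate_roughly_equally_spaced_steps(4, 1000) returns [249, 499, 749, 999], so sampling always starts from the highest noise level (what the removed checkbox used to toggle). A hedged sketch with the default schedule values, assuming make_beta_schedule's "linear" variant squares a linspace of sqrt(beta) as in the LDM codebase:

import numpy as np

num_timesteps = 1000
betas = np.linspace(0.00085 ** 0.5, 0.0120 ** 0.5, num_timesteps) ** 2
alphas_cumprod = np.cumprod(1.0 - betas)

def sigma(t):
    # Same conversion as in get_sigmas: sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t)
    return float(((1 - alphas_cumprod[t]) / alphas_cumprod[t]) ** 0.5)

old_top = sigma(751)  # highest index reached by the old n=4 schedule
new_top = sigma(999)  # the new helper always includes the last index
assert new_top > old_top  # the new schedule starts deeper in the noise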