Mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-19 06:14:21 +01:00)
Commit message: soon is now
sgm/modules/diffusionmodules/discretizer.py | 65 lines (new file)
@@ -0,0 +1,65 @@
import torch
import numpy as np
from functools import partial

from ...util import append_zero
from ...modules.diffusionmodules.util import make_beta_schedule


class Discretization:
    # Base class: subclasses provide get_sigmas(n, device). __call__ optionally
    # appends a final sigma of zero (the fully denoised step) and/or flips the
    # schedule into ascending order.
    def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
        sigmas = self.get_sigmas(n, device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))

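# --- Illustration only, not part of this commit: a toy subclass showing the
# contract the base class expects (get_sigmas returns descending sigmas). ---
class _ToyDiscretization(Discretization):
    def get_sigmas(self, n, device):
        # any descending schedule will do for this sketch
        return torch.linspace(1.0, 0.1, n, device=device)

# _ToyDiscretization()(4, device="cpu") -> tensor([1.0, 0.7, 0.4, 0.1, 0.0]);
# the trailing 0.0 comes from append_zero, and flip=True reverses the order.
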
class EDMDiscretization(Discretization):
    # The rho-spaced schedule from Karras et al. 2022 ("Elucidating the Design
    # Space of Diffusion-Based Generative Models", EDM):
    # sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def get_sigmas(self, n, device):
        ramp = torch.linspace(0, 1, n, device=device)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        # interpolate linearly in sigma^(1/rho) space, then map back
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        return sigmas

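# --- Illustration only, not part of this commit: a sanity check that the EDM
# schedule hits its endpoints. Larger rho makes the steps denser near
# sigma_min at the expense of coarser steps near sigma_max. ---
_edm = EDMDiscretization()(10, device="cpu")
assert _edm.shape == (11,)                    # 10 sigmas + appended 0.0
assert abs(_edm[0].item() - 80.0) < 1e-2      # starts at sigma_max
assert abs(_edm[-2].item() - 0.02) < 1e-4     # ends at sigma_min
assert _edm[-1] == 0.0                        # trailing zero from append_zero
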
class LegacyDDPMDiscretization(Discretization):
    # Reproduces the discrete DDPM/LDM noise schedule: a linear beta schedule
    # over num_timesteps training steps, converted to continuous sigmas.
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
        legacy_range=True,
    ):
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)
        self.legacy_range = legacy_range

    def get_sigmas(self, n, device):
        if n < self.num_timesteps:
            # subsample every c-th of the num_timesteps training steps
            c = self.num_timesteps // n

            if self.legacy_range:
                timesteps = np.asarray(list(range(0, self.num_timesteps, c)))
                timesteps += 1  # Legacy LDM Hack
            else:
                timesteps = np.asarray(list(range(0, self.num_timesteps + 1, c)))
                timesteps -= 1
                timesteps = timesteps[1:]

            alphas_cumprod = self.alphas_cumprod[timesteps]
        else:
            alphas_cumprod = self.alphas_cumprod

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        # sigma_t = sqrt((1 - alphas_cumprod_t) / alphas_cumprod_t), flipped so
        # the largest noise level comes first
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        return torch.flip(sigmas, (0,))
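
# --- Illustration only, not part of this commit: asking for n=50 of the 1000
# training steps subsamples every 1000 // 50 = 20th timestep (shifted by +1
# when legacy_range=True, matching LDM's DDIM timestep convention). ---
_ddpm = LegacyDDPMDiscretization()(50, device="cpu")
assert _ddpm.shape == (51,)          # 50 sigmas + appended 0.0
assert _ddpm[0] > _ddpm[1] > 0.0     # descending, largest noise level first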