Mirror of https://github.com/Stability-AI/generative-models.git, synced 2025-12-21 07:14:21 +01:00
Fix additional GPU memory usage on device 0 due to discretization
@@ -1,6 +1,7 @@
 import torch
 import numpy as np
 from functools import partial
+from abc import abstractmethod
 
 from ...util import append_zero
 from ...modules.diffusionmodules.util import make_beta_schedule
@@ -13,11 +14,15 @@ def generate_roughly_equally_spaced_steps(
 
 
 class Discretization:
-    def __call__(self, n, do_append_zero=True, device="cuda", flip=False):
-        sigmas = self.get_sigmas(n, device)
+    def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
+        sigmas = self.get_sigmas(n, device=device)
         sigmas = append_zero(sigmas) if do_append_zero else sigmas
         return sigmas if not flip else torch.flip(sigmas, (0,))
 
+    @abstractmethod
+    def get_sigmas(self, n, device):
+        raise NotImplementedError("abstract class should not be called")
+
 
 class EDMDiscretization(Discretization):
     def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
@@ -25,7 +30,7 @@ class EDMDiscretization(Discretization):
         self.sigma_max = sigma_max
         self.rho = rho
 
-    def get_sigmas(self, n, device):
+    def get_sigmas(self, n, device="cpu"):
         ramp = torch.linspace(0, 1, n, device=device)
         min_inv_rho = self.sigma_min ** (1 / self.rho)
         max_inv_rho = self.sigma_max ** (1 / self.rho)
@@ -40,6 +45,7 @@ class LegacyDDPMDiscretization(Discretization):
         linear_end=0.0120,
         num_timesteps=1000,
     ):
+        super().__init__()
         self.num_timesteps = num_timesteps
         betas = make_beta_schedule(
             "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
@@ -48,7 +54,7 @@ class LegacyDDPMDiscretization(Discretization):
         self.alphas_cumprod = np.cumprod(alphas, axis=0)
         self.to_torch = partial(torch.tensor, dtype=torch.float32)
 
-    def get_sigmas(self, n, device):
+    def get_sigmas(self, n, device="cpu"):
         if n < self.num_timesteps:
             timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
             alphas_cumprod = self.alphas_cumprod[timesteps]
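
The substance of the change is the device default. Before this commit, Discretization.__call__ defaulted to device="cuda", so merely constructing a sampler's noise schedule allocated a tensor (and, on first use, a CUDA context) on device 0, even when the model ran on another device; with device="cpu" the schedule is built host-side and moved only when the caller asks for it. A minimal usage sketch, assuming the class lives in sgm.modules.diffusionmodules.discretizer (hypothetical caller code, not from the repository):

import torch

from sgm.modules.diffusionmodules.discretizer import EDMDiscretization  # assumed module path

# With the new default, schedule construction stays on the CPU.
discretization = EDMDiscretization(sigma_min=0.02, sigma_max=80.0, rho=7.0)
sigmas = discretization(50)  # shape (51,): 50 sigmas plus the appended zero
assert sigmas.device.type == "cpu"

# The caller moves the schedule to the sampling device explicitly, once it
# knows which GPU the model actually lives on.
if torch.cuda.is_available():
    sigmas = sigmas.to("cuda:0")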
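
For reference, the context lines in the EDMDiscretization hunk implement the rho-schedule of Karras et al. (2022): sigma_i = (sigma_max^(1/rho) + (i/(n-1)) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho for i = 0, ..., n-1, i.e. a linear ramp in sigma^(1/rho) space running from sigma_max down to sigma_min. A self-contained sketch of that computation (plain PyTorch, mirroring the context lines rather than copied from the repository):

import torch

def edm_sigmas(n, sigma_min=0.02, sigma_max=80.0, rho=7.0, device="cpu"):
    # Interpolate linearly in sigma**(1/rho) space, then raise back to the
    # rho-th power; the ramp runs from sigma_max (ramp=0) to sigma_min (ramp=1).
    ramp = torch.linspace(0, 1, n, device=device)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

print(edm_sigmas(5))  # decreasing CPU tensor, from 80.0 down to 0.02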
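
LegacyDDPMDiscretization derives its sigmas from a DDPM beta schedule instead: make_beta_schedule produces linear betas, the class caches alpha_bar_t = prod_s(1 - beta_s), and the sigmas then follow the usual variance-preserving conversion sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t). That conversion sits outside the context shown above, so the following is a sketch under that assumption, with the "linear" schedule reconstructed in the LDM-style sqrt-space form (also an assumption, since make_beta_schedule's body is not part of this diff):

import numpy as np
import torch

def ddpm_sigmas(linear_start, linear_end=0.0120, num_timesteps=1000):
    # Assumed LDM-style "linear" schedule: linspace in sqrt(beta) space, squared.
    betas = np.linspace(linear_start**0.5, linear_end**0.5, num_timesteps) ** 2
    alphas_cumprod = np.cumprod(1.0 - betas, axis=0)
    # Assumed variance-preserving conversion: sigma_t = sqrt((1 - abar_t) / abar_t).
    sigmas = ((1.0 - alphas_cumprod) / alphas_cumprod) ** 0.5
    return torch.tensor(sigmas, dtype=torch.float32)  # built on the CPU; no GPU touched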