mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 14:24:21 +01:00
60 lines
2.1 KiB
Python
60 lines
2.1 KiB
Python
import torch
|
|
|
|
from sgm.modules.diffusionmodules.discretizer import Discretization
|
|
|
|
|
|
class Img2ImgDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas it returns.

    The wrapped discretization is called with whatever arguments this wrapper
    receives. Of the ``n`` sigmas it returns, only the last
    ``max(int(strength * n), 1)`` entries are kept, in their original order,
    so at least one sigma is always returned.

    params:
        discretization: callable that produces the sigmas to prune
        strength: float between 0.0 and 1.0. 1.0 means full sampling
            (all sigmas are returned)
    """

    def __init__(self, discretization: Discretization, strength: float = 1.0):
        self.discretization = discretization
        self.strength = strength
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        """Run the wrapped discretization and return the pruned sigmas."""
        # sigmas start large first, and decrease then
        sigmas = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", sigmas)

        # Compute the cut-off exactly once, from the *unpruned* length, so the
        # value that is logged is the value that was actually used to slice.
        prune_index = max(int(self.strength * len(sigmas)), 1)

        # Flip to ascending order, keep the first `prune_index` entries, and
        # flip back: this keeps the tail of the original sequence.
        sigmas = torch.flip(sigmas, (0,))
        sigmas = sigmas[:prune_index]
        print("prune index:", prune_index)
        sigmas = torch.flip(sigmas, (0,))
        print("sigmas after pruning: ", sigmas)
        return sigmas
|
|
|
|
|
|
class Txt2NoisyDiscretizationWrapper:
    """
    Wraps a discretizer and prunes the sigmas it returns.

    The sigmas are cut at a prune index derived from ``strength``; entries
    from that index onward (counted from the end of the returned sequence)
    are kept, in their original order.

    params:
        discretization: callable that produces the sigmas to prune
        strength: float between 0.0 and 1.0. 0.0 means full sampling
            (all sigmas are returned)
        original_steps: optional; when given, the prune index is computed
            from ``original_steps + 1`` instead of the number of sigmas
            returned by the discretization
    """

    def __init__(
        self, discretization: Discretization, strength: float = 0.0, original_steps=None
    ):
        self.discretization = discretization
        self.strength = strength
        self.original_steps = original_steps
        assert 0.0 <= self.strength <= 1.0

    def __call__(self, *args, **kwargs):
        """Run the wrapped discretization and return the pruned sigmas."""
        # the discretization yields sigmas from large to small
        descending = self.discretization(*args, **kwargs)
        print("sigmas after discretization, before pruning img2img: ", descending)

        ascending = torch.flip(descending, (0,))

        # Step count used for the prune index: the actual number of sigmas,
        # unless an explicit original step count was supplied.
        total = (
            len(ascending)
            if self.original_steps is None
            else self.original_steps + 1
        )

        # Clamp the raw index into [0, total - 1].
        cut = int(self.strength * total) - 1
        cut = min(cut, total - 1)
        cut = max(cut, 0)

        kept = ascending[cut:]
        print("prune index:", cut)

        pruned = torch.flip(kept, (0,))
        print("sigmas after pruning: ", pruned)
        return pruned
|