mirror of
https://github.com/Stability-AI/generative-models.git
synced 2026-02-02 04:14:27 +01:00
set a default scale
This commit is contained in:
@@ -68,7 +68,7 @@ class SamplingParams:
|
||||
discretization: Discretization = Discretization.LEGACY_DDPM
|
||||
guider: Guider = Guider.VANILLA
|
||||
thresholder: Thresholder = Thresholder.NONE
|
||||
scale: float
|
||||
scale: float = 5.0
|
||||
aesthetic_score: float = 6.0
|
||||
negative_aesthetic_score: float = 2.5
|
||||
img2img_strength: float = 1.0
|
||||
@@ -185,17 +185,13 @@ model_specs = {
|
||||
}
|
||||
|
||||
|
||||
def wrap_discretization(
|
||||
discretization, image_strength=None, noise_strength=None, steps=None
|
||||
):
|
||||
def wrap_discretization(discretization, image_strength=None, noise_strength=None, steps=None):
|
||||
if isinstance(discretization, Img2ImgDiscretizationWrapper) or isinstance(
|
||||
discretization, Txt2NoisyDiscretizationWrapper
|
||||
):
|
||||
return discretization # Already wrapped
|
||||
if image_strength is not None and image_strength < 1.0 and image_strength > 0.0:
|
||||
discretization = Img2ImgDiscretizationWrapper(
|
||||
discretization, strength=image_strength
|
||||
)
|
||||
discretization = Img2ImgDiscretizationWrapper(discretization, strength=image_strength)
|
||||
|
||||
if (
|
||||
noise_strength is not None
|
||||
@@ -249,19 +245,13 @@ class SamplingPipeline:
|
||||
self.config = os.path.join(config_path, "inference", self.specs.config)
|
||||
self.ckpt = os.path.join(model_path, self.specs.ckpt)
|
||||
if not os.path.exists(self.config):
|
||||
raise ValueError(
|
||||
f"Config {self.config} not found, check model spec or config_path"
|
||||
)
|
||||
raise ValueError(f"Config {self.config} not found, check model spec or config_path")
|
||||
if not os.path.exists(self.ckpt):
|
||||
raise ValueError(
|
||||
f"Checkpoint {self.ckpt} not found, check model spec or config_path"
|
||||
)
|
||||
raise ValueError(f"Checkpoint {self.ckpt} not found, check model spec or config_path")
|
||||
|
||||
self.device_manager = get_model_manager(device)
|
||||
|
||||
self.model = self._load_model(
|
||||
device_manager=self.device_manager, use_fp16=use_fp16
|
||||
)
|
||||
self.model = self._load_model(device_manager=self.device_manager, use_fp16=use_fp16)
|
||||
|
||||
def _load_model(self, device_manager: DeviceModelManager, use_fp16=True):
|
||||
config = OmegaConf.load(self.config)
|
||||
@@ -406,9 +396,7 @@ class SamplingPipeline:
|
||||
def get_guider_config(params: SamplingParams) -> Dict[str, Any]:
|
||||
guider_config: Dict[str, Any]
|
||||
if params.guider == Guider.IDENTITY:
|
||||
guider_config = {
|
||||
"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"
|
||||
}
|
||||
guider_config = {"target": "sgm.modules.diffusionmodules.guiders.IdentityGuider"}
|
||||
elif params.guider == Guider.VANILLA:
|
||||
scale = params.scale
|
||||
|
||||
|
||||
Reference in New Issue
Block a user