mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-22 07:44:22 +01:00
soon is now
This commit is contained in:
231
sgm/util.py
Normal file
231
sgm/util.py
Normal file
@@ -0,0 +1,231 @@
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
def get_string_from_tuple(s):
|
||||
try:
|
||||
# Check if the string starts and ends with parentheses
|
||||
if s[0] == "(" and s[-1] == ")":
|
||||
# Convert the string to a tuple
|
||||
t = eval(s)
|
||||
# Check if the type of t is tuple
|
||||
if type(t) == tuple:
|
||||
return t[0]
|
||||
else:
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
return s
|
||||
|
||||
|
||||
def is_power_of_two(n):
|
||||
"""
|
||||
chat.openai.com/chat
|
||||
Return True if n is a power of 2, otherwise return False.
|
||||
|
||||
The function is_power_of_two takes an integer n as input and returns True if n is a power of 2, otherwise it returns False.
|
||||
The function works by first checking if n is less than or equal to 0. If n is less than or equal to 0, it can't be a power of 2, so the function returns False.
|
||||
If n is greater than 0, the function checks whether n is a power of 2 by using a bitwise AND operation between n and n-1. If n is a power of 2, then it will have only one bit set to 1 in its binary representation. When we subtract 1 from a power of 2, all the bits to the right of that bit become 1, and the bit itself becomes 0. So, when we perform a bitwise AND between n and n-1, we get 0 if n is a power of 2, and a non-zero value otherwise.
|
||||
Thus, if the result of the bitwise AND operation is 0, then n is a power of 2 and the function returns True. Otherwise, the function returns False.
|
||||
|
||||
"""
|
||||
if n <= 0:
|
||||
return False
|
||||
return (n & (n - 1)) == 0
|
||||
|
||||
|
||||
def autocast(f, enabled=True):
|
||||
def do_autocast(*args, **kwargs):
|
||||
with torch.cuda.amp.autocast(
|
||||
enabled=enabled,
|
||||
dtype=torch.get_autocast_gpu_dtype(),
|
||||
cache_enabled=torch.is_autocast_cache_enabled(),
|
||||
):
|
||||
return f(*args, **kwargs)
|
||||
|
||||
return do_autocast
|
||||
|
||||
|
||||
def load_partial_from_config(config):
|
||||
return partial(get_obj_from_str(config["target"]), **config.get("params", dict()))
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
if isinstance(xc[bi], list):
|
||||
text_seq = xc[bi][0]
|
||||
else:
|
||||
text_seq = xc[bi]
|
||||
lines = "\n".join(
|
||||
text_seq[start : start + nc] for start in range(0, len(text_seq), nc)
|
||||
)
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def partialclass(cls, *args, **kwargs):
|
||||
class NewCls(cls):
|
||||
__init__ = functools.partialmethod(cls.__init__, *args, **kwargs)
|
||||
|
||||
return NewCls
|
||||
|
||||
|
||||
def make_path_absolute(path):
|
||||
fs, p = fsspec.core.url_to_fs(path)
|
||||
if fs.protocol == "file":
|
||||
return os.path.abspath(p)
|
||||
return path
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def isheatmap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
|
||||
return x.ndim == 2
|
||||
|
||||
|
||||
def isneighbors(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return x.ndim == 5 and (x.shape[2] == 3 or x.shape[2] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def expand_dims_like(x, y):
|
||||
while x.dim() != y.dim():
|
||||
x = x.unsqueeze(-1)
|
||||
return x
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if config == "__is_first_stage__":
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False, invalidate_cache=True):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if invalidate_cache:
|
||||
importlib.invalidate_caches()
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
|
||||
)
|
||||
return x[(...,) + (None,) * dims_to_append]
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
|
||||
print(f"Loading model from {ckpt}")
|
||||
if ckpt.endswith("ckpt"):
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
elif ckpt.endswith("safetensors"):
|
||||
sd = load_safetensors(ckpt)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
model = instantiate_from_config(config.model)
|
||||
sd = pl_sd["state_dict"]
|
||||
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
if freeze:
|
||||
for param in model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
model.eval()
|
||||
return model
|
||||
Reference in New Issue
Block a user