Mirror of https://github.com/Stability-AI/generative-models.git, synced 2025-12-20 14:54:21 +01:00
This reverts commit 6f6d3f8716.
sgm/util.py (24 lines changed)
@@ -1,6 +1,5 @@
 import functools
 import importlib
-import logging
 import os
 from functools import partial
 from inspect import isfunction
@@ -11,8 +10,6 @@ import torch
 from PIL import Image, ImageDraw, ImageFont
 from safetensors.torch import load_file as load_safetensors
 
-logger = logging.getLogger(__name__)
-
 
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
@@ -89,7 +86,7 @@ def log_txt_as_img(wh, xc, size=10):
         try:
             draw.text((0, 0), lines, fill="black", font=font)
         except UnicodeEncodeError:
-            logger.warning("Cant encode string %r for logging. Skipping.", lines)
+            print("Cant encode string for logging. Skipping.")
 
         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
         txts.append(txt)
@@ -164,7 +161,7 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
     return total_params
 
 
@@ -203,11 +200,11 @@ def append_dims(x, target_dims):
 
 
 def load_model_from_config(config, ckpt, verbose=True, freeze=True):
-    logger.info(f"Loading model from {ckpt}")
+    print(f"Loading model from {ckpt}")
     if ckpt.endswith("ckpt"):
         pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
-            logger.debug(f"Global Step: {pl_sd['global_step']}")
+            print(f"Global Step: {pl_sd['global_step']}")
         sd = pl_sd["state_dict"]
     elif ckpt.endswith("safetensors"):
         sd = load_safetensors(ckpt)
@@ -216,13 +213,14 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
 
     model = instantiate_from_config(config.model)
 
-    missing, unexpected = model.load_state_dict(sd, strict=False)
+    m, u = model.load_state_dict(sd, strict=False)
 
-    if verbose:
-        if missing:
-            logger.info("missing keys: %r", missing)
-        if unexpected:
-            logger.info("unexpected keys: %r", unexpected)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
 
     if freeze:
         for param in model.parameters():
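For context, the net effect of the revert is to restore plain print-based reporting in place of the module-level logger. The snippet below is a minimal, self-contained sketch of the restored count_params behaviour; the nn.Linear module is only a stand-in chosen for illustration and is not part of the repository code.

import torch.nn as nn

# Stand-in module; any torch.nn.Module is counted the same way.
model = nn.Linear(128, 64)

# Mirrors the restored count_params(model, verbose=True) logic:
# sum the element counts of all parameters and report them in millions.
total_params = sum(p.numel() for p in model.parameters())
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")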
||||