Revert "Replace most print()s with logging calls (#42)" (#65)

This reverts commit 6f6d3f8716.
This commit is contained in:
Jonas Müller
2023-07-26 10:30:21 +02:00
committed by GitHub
parent 7934245835
commit 4a3f0f546e
10 changed files with 91 additions and 117 deletions

View File

@@ -1,6 +1,5 @@
import functools
import importlib
import logging
import os
from functools import partial
from inspect import isfunction
@@ -11,8 +10,6 @@ import torch
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors
logger = logging.getLogger(__name__)
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
@@ -89,7 +86,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
logger.warning("Cant encode string %r for logging. Skipping.", lines)
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -164,7 +161,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
@@ -203,11 +200,11 @@ def append_dims(x, target_dims):
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
logger.info(f"Loading model from {ckpt}")
print(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
logger.debug(f"Global Step: {pl_sd['global_step']}")
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
@@ -216,13 +213,14 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
model = instantiate_from_config(config.model)
missing, unexpected = model.load_state_dict(sd, strict=False)
m, u = model.load_state_dict(sd, strict=False)
if verbose:
if missing:
logger.info("missing keys: %r", missing)
if unexpected:
logger.info("unexpected keys: %r", unexpected)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if freeze:
for param in model.parameters():