Replace most print()s with logging calls (#42)

This commit is contained in:
Aarni Koskela
2023-07-25 16:21:30 +03:00
committed by GitHub
parent 6ecd0a900a
commit 6f6d3f8716
10 changed files with 118 additions and 92 deletions

View File

@@ -1,5 +1,6 @@
import functools
import importlib
import logging
import os
from functools import partial
from inspect import isfunction
@@ -10,6 +11,8 @@ import torch
from PIL import Image, ImageDraw, ImageFont
from safetensors.torch import load_file as load_safetensors
logger = logging.getLogger(__name__)
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
@@ -86,7 +89,7 @@ def log_txt_as_img(wh, xc, size=10):
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
logger.warning("Cant encode string %r for logging. Skipping.", lines)
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
@@ -161,7 +164,7 @@ def mean_flat(tensor):
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
return total_params
@@ -200,11 +203,11 @@ def append_dims(x, target_dims):
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
print(f"Loading model from {ckpt}")
logger.info(f"Loading model from {ckpt}")
if ckpt.endswith("ckpt"):
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
logger.debug(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
elif ckpt.endswith("safetensors"):
sd = load_safetensors(ckpt)
@@ -213,14 +216,13 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
missing, unexpected = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
if verbose:
if missing:
logger.info("missing keys: %r", missing)
if unexpected:
logger.info("unexpected keys: %r", unexpected)
if freeze:
for param in model.parameters():