Replace most print()s with logging calls (#42)

This commit is contained in:
Aarni Koskela
2023-07-25 16:21:30 +03:00
committed by GitHub
parent 6ecd0a900a
commit 6f6d3f8716
10 changed files with 118 additions and 92 deletions

View File

@@ -1,3 +1,4 @@
import logging
import re
from abc import abstractmethod
from contextlib import contextmanager
@@ -14,6 +15,8 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config
logger = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
"""
@@ -38,7 +41,7 @@ class AbstractAutoencoder(pl.LightningModule):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -60,16 +63,16 @@ class AbstractAutoencoder(pl.LightningModule):
for k in keys:
for ik in ignore_keys:
if re.match(ik, k):
print("Deleting key {} from state_dict.".format(k))
logger.debug(f"Deleting key {k} from state_dict.")
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
logger.debug(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
logger.info(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
logger.info(f"Unexpected Keys: {unexpected}")
@abstractmethod
def get_input(self, batch) -> Any:
@@ -86,14 +89,14 @@ class AbstractAutoencoder(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
logger.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
logger.info(f"{context}: Restored training weights")
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -104,7 +107,7 @@ class AbstractAutoencoder(pl.LightningModule):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
print(f"loading >>> {cfg['target']} <<< optimizer from config")
logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)