Revert "Replace most print()s with logging calls (#42)" (#65)

This reverts commit 6f6d3f8716.
This commit is contained in:
Jonas Müller
2023-07-26 10:30:21 +02:00
committed by GitHub
parent 7934245835
commit 4a3f0f546e
10 changed files with 91 additions and 117 deletions

View File

@@ -1,4 +1,3 @@
import logging
import re
from abc import abstractmethod
from contextlib import contextmanager
@@ -15,8 +14,6 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
from ..modules.ema import LitEma
from ..util import default, get_obj_from_str, instantiate_from_config
logger = logging.getLogger(__name__)
class AbstractAutoencoder(pl.LightningModule):
"""
@@ -41,7 +38,7 @@ class AbstractAutoencoder(pl.LightningModule):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -63,16 +60,16 @@ class AbstractAutoencoder(pl.LightningModule):
for k in keys:
for ik in ignore_keys:
if re.match(ik, k):
logger.debug(f"Deleting key {k} from state_dict.")
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
logger.debug(
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
logger.info(f"Missing Keys: {missing}")
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
logger.info(f"Unexpected Keys: {unexpected}")
print(f"Unexpected Keys: {unexpected}")
@abstractmethod
def get_input(self, batch) -> Any:
@@ -89,14 +86,14 @@ class AbstractAutoencoder(pl.LightningModule):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
logger.info(f"{context}: Switched to EMA weights")
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
logger.info(f"{context}: Restored training weights")
print(f"{context}: Restored training weights")
@abstractmethod
def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -107,7 +104,7 @@ class AbstractAutoencoder(pl.LightningModule):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
print(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)