Revert "Replace most print()s with logging calls (#42)" (#65)

This reverts commit 6f6d3f8716.
This commit is contained in:
Jonas Müller
2023-07-26 10:30:21 +02:00
committed by GitHub
parent 7934245835
commit 4a3f0f546e
10 changed files with 91 additions and 117 deletions

View File

@@ -1,4 +1,3 @@
import logging
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union
@@ -19,8 +18,6 @@ from ..util import (
log_txt_as_img,
)
logger = logging.getLogger(__name__)
class DiffusionEngine(pl.LightningModule):
def __init__(
@@ -76,7 +73,7 @@ class DiffusionEngine(pl.LightningModule):
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -97,13 +94,13 @@ class DiffusionEngine(pl.LightningModule):
raise NotImplementedError
missing, unexpected = self.load_state_dict(sd, strict=False)
logger.info(
print(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
logger.info(f"Missing Keys: {missing}")
print(f"Missing Keys: {missing}")
if len(unexpected) > 0:
logger.info(f"Unexpected Keys: {unexpected}")
print(f"Unexpected Keys: {unexpected}")
def _init_first_stage(self, config):
model = instantiate_from_config(config).eval()
@@ -182,14 +179,14 @@ class DiffusionEngine(pl.LightningModule):
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
logger.info(f"{context}: Switched to EMA weights")
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
logger.info(f"{context}: Restored training weights")
print(f"{context}: Restored training weights")
def instantiate_optimizer_from_config(self, params, lr, cfg):
return get_obj_from_str(cfg["target"])(
@@ -205,7 +202,7 @@ class DiffusionEngine(pl.LightningModule):
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
logger.debug("Setting up LambdaLR scheduler...")
print("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),