Replace most print()s with logging calls (#42)

This commit is contained in:
Aarni Koskela
2023-07-25 16:21:30 +03:00
committed by GitHub
parent 6ecd0a900a
commit 6f6d3f8716
10 changed files with 118 additions and 92 deletions

View File

@@ -1,3 +1,4 @@
import logging
from contextlib import contextmanager
from typing import Any, Dict, List, Tuple, Union
@@ -18,6 +19,8 @@ from ..util import (
log_txt_as_img,
)
logger = logging.getLogger(__name__)
class DiffusionEngine(pl.LightningModule):
def __init__(
@@ -73,7 +76,7 @@ class DiffusionEngine(pl.LightningModule):
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
self.scale_factor = scale_factor
self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -94,13 +97,13 @@ class DiffusionEngine(pl.LightningModule):
raise NotImplementedError
missing, unexpected = self.load_state_dict(sd, strict=False)
print(
logger.info(
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
)
if len(missing) > 0:
print(f"Missing Keys: {missing}")
logger.info(f"Missing Keys: {missing}")
if len(unexpected) > 0:
print(f"Unexpected Keys: {unexpected}")
logger.info(f"Unexpected Keys: {unexpected}")
def _init_first_stage(self, config):
model = instantiate_from_config(config).eval()
@@ -179,14 +182,14 @@ class DiffusionEngine(pl.LightningModule):
self.model_ema.store(self.model.parameters())
self.model_ema.copy_to(self.model)
if context is not None:
print(f"{context}: Switched to EMA weights")
logger.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.model.parameters())
if context is not None:
print(f"{context}: Restored training weights")
logger.info(f"{context}: Restored training weights")
def instantiate_optimizer_from_config(self, params, lr, cfg):
return get_obj_from_str(cfg["target"])(
@@ -202,7 +205,7 @@ class DiffusionEngine(pl.LightningModule):
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
logger.debug("Setting up LambdaLR scheduler...")
scheduler = [
{
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
@@ -304,7 +307,7 @@ class DiffusionEngine(pl.LightningModule):
log["inputs"] = x
z = self.encode_first_stage(x)
log["reconstructions"] = self.decode_first_stage(z)
log.update(self.log_conditionings(batch, N))
log.update(self.log_conditionings(batch, N))
for k in c:
if isinstance(c[k], torch.Tensor):