Replace most print()s with logging calls (#42)

Aarni Koskela
2023-07-25 16:21:30 +03:00
committed by GitHub
parent 6ecd0a900a
commit 6f6d3f8716
10 changed files with 118 additions and 92 deletions
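The change follows the standard library-logging pattern: each module creates a module-level logger with `logging.getLogger(__name__)`, chatty construction messages move from `print()` to `logger.debug()`, and user-facing notices (such as the dropped `use_fp16` flag) become `logger.warning()`. Below is a minimal sketch of that pattern; the `Downsample` stand-in class and its arguments are illustrative only, not the repository's actual module.

```python
import logging

# Module-level logger; the library attaches no handlers itself, so nothing is
# emitted unless the host application configures logging.
logger = logging.getLogger(__name__)


class Downsample:
    """Illustrative stand-in for the repository's nn.Module subclass."""

    def __init__(self, channels, use_conv, dims=2, use_fp16=False):
        self.channels = channels
        if use_conv:
            # Previously: print(f"Building a Downsample layer with {dims} dims.")
            logger.debug(f"Building a Downsample layer with {dims} dims.")
        if use_fp16:
            # Previously: print("WARNING: use_fp16 was dropped and has no effect anymore.")
            logger.warning("use_fp16 was dropped and has no effect anymore.")


if __name__ == "__main__":
    # The application, not the library, decides whether debug output is visible.
    logging.basicConfig(level=logging.DEBUG)
    Downsample(channels=64, use_conv=True)
```

With no handler configured, Python's handler of last resort only emits messages at WARNING level and above, so the deprecation notice stays visible by default while the DEBUG-level construction chatter is hidden unless the application opts in, e.g. with `logging.basicConfig(level=logging.DEBUG)`.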


@@ -1,3 +1,4 @@
+import logging
import math
from abc import abstractmethod
from functools import partial
@@ -21,6 +22,8 @@ from ...modules.diffusionmodules.util import (
)
from ...util import default, exists
+logger = logging.getLogger(__name__)
# dummy replace
def convert_module_to_f16(x):
@@ -177,13 +180,13 @@ class Downsample(nn.Module):
self.dims = dims
stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
if use_conv:
print(f"Building a Downsample layer with {dims} dims.")
print(
logger.debug(
f"Building a Downsample layer with {dims} dims.\n"
f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
f"kernel-size: 3, stride: {stride}, padding: {padding}"
)
if dims == 3:
print(f" --> Downsampling third axis (time): {third_down}")
logger.debug(f" --> Downsampling third axis (time): {third_down}")
self.op = conv_nd(
dims,
self.channels,
@@ -270,7 +273,7 @@ class ResBlock(TimestepBlock):
2 * self.out_channels if use_scale_shift_norm else self.out_channels
)
if self.skip_t_emb:
print(f"Skipping timestep embedding in {self.__class__.__name__}")
logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
assert not self.use_scale_shift_norm
self.emb_layers = None
self.exchange_temb_dims = False
@@ -619,12 +622,12 @@ class UNetModel(nn.Module):
range(len(num_attention_blocks)),
)
)
-print(
+logger.warning(
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set."
-) # todo: convert to warning
+)
self.attention_resolutions = attention_resolutions
self.dropout = dropout
@@ -633,7 +636,7 @@ class UNetModel(nn.Module):
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
if use_fp16:
print("WARNING: use_fp16 was dropped and has no effect anymore.")
logger.warning("use_fp16 was dropped and has no effect anymore.")
# self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
@@ -664,7 +667,7 @@ class UNetModel(nn.Module):
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
print("setting up linear c_adm embedding layer")
logger.debug("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "timestep":
self.label_emb = checkpoint_wrapper_fn(