Revert "Replace most print()s with logging calls (#42)" (#65)

This reverts commit 6f6d3f8716.
This commit is contained in:
Jonas Müller
2023-07-26 10:30:21 +02:00
committed by GitHub
parent 7934245835
commit 4a3f0f546e
10 changed files with 91 additions and 117 deletions

View File

@@ -1,4 +1,3 @@
import logging
import math
from abc import abstractmethod
from functools import partial
@@ -22,8 +21,6 @@ from ...modules.diffusionmodules.util import (
)
from ...util import default, exists
logger = logging.getLogger(__name__)
# dummy replace
def convert_module_to_f16(x):
@@ -180,13 +177,13 @@ class Downsample(nn.Module):
self.dims = dims
stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
if use_conv:
logger.debug(
f"Building a Downsample layer with {dims} dims.\n"
print(f"Building a Downsample layer with {dims} dims.")
print(
f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
f"kernel-size: 3, stride: {stride}, padding: {padding}"
)
if dims == 3:
logger.debug(f" --> Downsampling third axis (time): {third_down}")
print(f" --> Downsampling third axis (time): {third_down}")
self.op = conv_nd(
dims,
self.channels,
@@ -273,7 +270,7 @@ class ResBlock(TimestepBlock):
2 * self.out_channels if use_scale_shift_norm else self.out_channels
)
if self.skip_t_emb:
logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
print(f"Skipping timestep embedding in {self.__class__.__name__}")
assert not self.use_scale_shift_norm
self.emb_layers = None
self.exchange_temb_dims = False
@@ -622,12 +619,12 @@ class UNetModel(nn.Module):
range(len(num_attention_blocks)),
)
)
logger.warning(
print(
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set."
)
) # todo: convert to warning
self.attention_resolutions = attention_resolutions
self.dropout = dropout
@@ -636,7 +633,7 @@ class UNetModel(nn.Module):
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
if use_fp16:
logger.warning("use_fp16 was dropped and has no effect anymore.")
print("WARNING: use_fp16 was dropped and has no effect anymore.")
# self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
@@ -667,7 +664,7 @@ class UNetModel(nn.Module):
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == "continuous":
logger.debug("setting up linear c_adm embedding layer")
print("setting up linear c_adm embedding layer")
self.label_emb = nn.Linear(1, time_embed_dim)
elif self.num_classes == "timestep":
self.label_emb = checkpoint_wrapper_fn(