Revert "Replace most print()s with logging calls (#42)" (#65)

This reverts commit 6f6d3f8716.
This commit is contained in:
Jonas Müller
2023-07-26 10:30:21 +02:00
committed by GitHub
parent 7934245835
commit 4a3f0f546e
10 changed files with 91 additions and 117 deletions

View File

@@ -1,4 +1,3 @@
import logging
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
@@ -33,8 +32,6 @@ from ...util import (
instantiate_from_config,
)
logger = logging.getLogger(__name__)
class AbstractEmbModel(nn.Module):
def __init__(self):
@@ -99,7 +96,7 @@ class GeneralConditioner(nn.Module):
for param in embedder.parameters():
param.requires_grad = False
embedder.eval()
logger.debug(
print(
f"Initialized embedder #{n}: {embedder.__class__.__name__} "
f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
)
@@ -730,7 +727,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
)
if tokens is not None:
tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
logger.warning(
print(
f"You are running very experimental token-concat in {self.__class__.__name__}. "
f"Check what you are doing, and then remove this message."
)
@@ -756,7 +753,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
clip_version, device, max_length=clip_max_length
)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
logger.debug(
print(
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
)
@@ -798,7 +795,7 @@ class SpatialRescaler(nn.Module):
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None or remap_output
if self.remap_output:
logger.debug(
print(
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
)
self.channel_mapper = nn.Conv2d(