Replace most print()s with logging calls (#42)

This commit is contained in:
Aarni Koskela
2023-07-25 16:21:30 +03:00
committed by GitHub
parent 6ecd0a900a
commit 6f6d3f8716
10 changed files with 118 additions and 92 deletions

View File

@@ -1,3 +1,4 @@
import logging
from contextlib import nullcontext
from functools import partial
from typing import Dict, List, Optional, Tuple, Union
@@ -32,6 +33,8 @@ from ...util import (
instantiate_from_config,
)
logger = logging.getLogger(__name__)
class AbstractEmbModel(nn.Module):
def __init__(self):
@@ -96,7 +99,7 @@ class GeneralConditioner(nn.Module):
for param in embedder.parameters():
param.requires_grad = False
embedder.eval()
print(
logger.debug(
f"Initialized embedder #{n}: {embedder.__class__.__name__} "
f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
)
@@ -727,7 +730,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
)
if tokens is not None:
tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
print(
logger.warning(
f"You are running very experimental token-concat in {self.__class__.__name__}. "
f"Check what you are doing, and then remove this message."
)
@@ -753,7 +756,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
clip_version, device, max_length=clip_max_length
)
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
print(
logger.debug(
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
)
@@ -795,7 +798,7 @@ class SpatialRescaler(nn.Module):
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None or remap_output
if self.remap_output:
print(
logger.debug(
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
)
self.channel_mapper = nn.Conv2d(