Revert "Replace most print()s with logging calls (#42)" (#65)

This reverts commit 6f6d3f8716.
This commit is contained in:
Jonas Müller
2023-07-26 10:30:21 +02:00
committed by GitHub
parent 7934245835
commit 4a3f0f546e
10 changed files with 91 additions and 117 deletions

View File

@@ -1,4 +1,3 @@
import logging
import math
from inspect import isfunction
from typing import Any, Optional
@@ -9,10 +8,6 @@ from einops import rearrange, repeat
from packaging import version
from torch import nn
logger = logging.getLogger(__name__)
if version.parse(torch.__version__) >= version.parse("2.0.0"):
SDP_IS_AVAILABLE = True
from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -41,9 +36,9 @@ else:
SDP_IS_AVAILABLE = False
sdp_kernel = nullcontext
BACKEND_MAP = {}
logger.warning(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
print(
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
)
try:
@@ -53,7 +48,7 @@ try:
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
logger.debug("no module 'xformers'. Processing without...")
print("no module 'xformers'. Processing without...")
from .diffusionmodules.util import checkpoint
@@ -294,7 +289,7 @@ class MemoryEfficientCrossAttention(nn.Module):
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
):
super().__init__()
logger.info(
print(
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
f"{heads} heads with a dimension of {dim_head}."
)
@@ -398,21 +393,22 @@ class BasicTransformerBlock(nn.Module):
super().__init__()
assert attn_mode in self.ATTENTION_MODES
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
logger.warning(
print(
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
)
attn_mode = "softmax"
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
logger.warning(
print(
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
)
if not XFORMERS_IS_AVAILABLE:
raise NotImplementedError(
"Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
logger.info("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
assert (
False
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
else:
print("Falling back to xformers efficient attention.")
attn_mode = "softmax-xformers"
attn_cls = self.ATTENTION_MODES[attn_mode]
if version.parse(torch.__version__) >= version.parse("2.0.0"):
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -441,7 +437,7 @@ class BasicTransformerBlock(nn.Module):
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
if self.checkpoint:
logger.info(f"{self.__class__.__name__} is using checkpointing")
print(f"{self.__class__.__name__} is using checkpointing")
def forward(
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -558,7 +554,7 @@ class SpatialTransformer(nn.Module):
sdp_backend=None,
):
super().__init__()
logger.debug(
print(
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
)
from omegaconf import ListConfig
@@ -567,8 +563,8 @@ class SpatialTransformer(nn.Module):
context_dim = [context_dim]
if exists(context_dim) and isinstance(context_dim, list):
if depth != len(context_dim):
logger.warning(
f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
print(
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
)
# depth does not match context dims.