Revert "Replace most print()s with logging calls (#42)" (#65)

This reverts commit 6f6d3f8716.
This commit is contained in:
Jonas Müller
2023-07-26 10:30:21 +02:00
committed by GitHub
parent 7934245835
commit 4a3f0f546e
10 changed files with 91 additions and 117 deletions

View File

@@ -1,5 +1,4 @@
# pytorch_diffusion + derived encoder decoder
import logging
import math
from typing import Any, Callable, Optional
@@ -9,8 +8,6 @@ import torch.nn as nn
from einops import rearrange
from packaging import version
logger = logging.getLogger(__name__)
try:
import xformers
import xformers.ops
@@ -18,7 +15,7 @@ try:
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
logger.debug("no module 'xformers'. Processing without...")
print("no module 'xformers'. Processing without...")
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
@@ -291,14 +288,12 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
attn_type = "vanilla-xformers"
logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
logger.debug(
f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
)
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
return MemoryEfficientAttnBlock(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
@@ -638,8 +633,10 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
logger.debug(
f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
)
make_attn_cls = self._make_attn()