Replace most print()s with logging calls (#42)

This commit is contained in:
Aarni Koskela
2023-07-25 16:21:30 +03:00
committed by GitHub
parent 6ecd0a900a
commit 6f6d3f8716
10 changed files with 118 additions and 92 deletions

View File

@@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
import logging
import math
from typing import Any, Callable, Optional
@@ -8,6 +9,8 @@ import torch.nn as nn
from einops import rearrange
from packaging import version
logger = logging.getLogger(__name__)
try:
import xformers
import xformers.ops
@@ -15,7 +18,7 @@ try:
XFORMERS_IS_AVAILABLE = True
except:
XFORMERS_IS_AVAILABLE = False
print("no module 'xformers'. Processing without...")
logger.debug("no module 'xformers'. Processing without...")
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
@@ -288,12 +291,14 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
)
attn_type = "vanilla-xformers"
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == "vanilla-xformers":
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
logger.debug(
f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
)
return MemoryEfficientAttnBlock(in_channels)
elif type == "memory-efficient-cross-attn":
attn_kwargs["query_dim"] = in_channels
@@ -633,10 +638,8 @@ class Decoder(nn.Module):
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print(
"Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)
)
logger.debug(
f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
)
make_attn_cls = self._make_attn()