mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-20 14:54:21 +01:00
Replace most print()s with logging calls (#42)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
@@ -8,6 +9,8 @@ import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
@@ -15,7 +18,7 @@ try:
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
print("no module 'xformers'. Processing without...")
|
||||
logger.debug("no module 'xformers'. Processing without...")
|
||||
|
||||
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
|
||||
|
||||
@@ -288,12 +291,14 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
)
|
||||
attn_type = "vanilla-xformers"
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
logger.debug(
|
||||
f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
|
||||
)
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
@@ -633,10 +638,8 @@ class Decoder(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
logger.debug(
|
||||
f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
|
||||
)
|
||||
|
||||
make_attn_cls = self._make_attn()
|
||||
|
||||
Reference in New Issue
Block a user