Revert "Replace most print()s with logging calls (#42)" (#65)

This reverts commit 6f6d3f8716.
Authored by Jonas Müller on 2023-07-26 10:30:21 +02:00, committed by GitHub
parent 7934245835
commit 4a3f0f546e
10 changed files with 91 additions and 117 deletions
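For context, the change being reverted (#42) replaced bare print() calls with the standard-library logging pattern visible throughout the diff below: one module-level logger named after the module, with a severity level per message. A minimal sketch of the two styles, assuming an illustrative helper name that is not taken from the repo:

import logging

# One logger per module, named after the module itself (the pattern introduced in #42).
logger = logging.getLogger(__name__)


def report_missing_keys(missing: list) -> None:
    # Style restored by this revert: write straight to stdout.
    print(f"Missing Keys: {missing}")
    # Style being reverted: route the message through logging handlers with a level.
    logger.info("Missing Keys: %s", missing)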

View File

@@ -1,4 +1,3 @@
-import logging
 from typing import Optional
 
 import torchdata.datapipes.iter
@@ -6,17 +5,16 @@ import webdataset as wds
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule
 
-logger = logging.getLogger(__name__)
-
 try:
     from sdata import create_dataset, create_dummy_dataset, create_loader
 except ImportError as e:
-    raise NotImplementedError(
-        "Datasets not yet available. "
-        "To enable, we need to add stable-datasets as a submodule; "
-        "please use ``git submodule update --init --recursive`` "
-        "and do ``pip install -e stable-datasets/`` from the root of this repo"
-    ) from e
+    print("#" * 100)
+    print("Datasets not yet available")
+    print("to enable, we need to add stable-datasets as a submodule")
+    print("please use ``git submodule update --init --recursive``")
+    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+    print("#" * 100)
+    exit(1)
 
 
 class StableDataModuleFromConfig(LightningDataModule):
@@ -41,8 +39,8 @@ class StableDataModuleFromConfig(LightningDataModule):
                 "datapipeline" in self.val_config and "loader" in self.val_config
             ), "validation config requires the fields `datapipeline` and `loader`"
         else:
-            logger.warning(
-                "No Validation datapipeline defined, using that one from training"
+            print(
+                "Warning: No Validation datapipeline defined, using that one from training"
             )
             self.val_config = train
@@ -54,10 +52,12 @@ class StableDataModuleFromConfig(LightningDataModule):
         self.dummy = dummy
         if self.dummy:
-            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
+            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
 
     def setup(self, stage: str) -> None:
-        logger.debug("Preparing datasets")
+        print("Preparing datasets")
         if self.dummy:
             data_fn = create_dummy_dataset
         else:

View File

@@ -1,9 +1,5 @@
-import logging
-
 import numpy as np
 
-logger = logging.getLogger(__name__)
-
 
 class LambdaWarmUpCosineScheduler:
     """
@@ -28,8 +24,9 @@ class LambdaWarmUpCosineScheduler:
         self.verbosity_interval = verbosity_interval
 
     def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         if n < self.lr_warm_up_steps:
             lr = (
                 self.lr_max - self.lr_start
@@ -86,11 +83,12 @@ class LambdaWarmUpCosineScheduler2:
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                 cycle
@@ -116,11 +114,12 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
 
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[

View File

@@ -1,4 +1,3 @@
-import logging
 import re
 from abc import abstractmethod
 from contextlib import contextmanager
@@ -15,8 +14,6 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
 from ..modules.ema import LitEma
 from ..util import default, get_obj_from_str, instantiate_from_config
 
-logger = logging.getLogger(__name__)
-
 
 class AbstractAutoencoder(pl.LightningModule):
     """
@@ -41,7 +38,7 @@ class AbstractAutoencoder(pl.LightningModule):
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -63,16 +60,16 @@ class AbstractAutoencoder(pl.LightningModule):
         for k in keys:
             for ik in ignore_keys:
                 if re.match(ik, k):
-                    logger.debug(f"Deleting key {k} from state_dict.")
+                    print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.debug(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")
 
     @abstractmethod
     def get_input(self, batch) -> Any:
@@ -89,14 +86,14 @@ class AbstractAutoencoder(pl.LightningModule):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")
 
     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -107,7 +104,7 @@ class AbstractAutoencoder(pl.LightningModule):
         raise NotImplementedError("decode()-method of abstract base class called")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
+        print(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )

View File

@@ -1,4 +1,3 @@
-import logging
 from contextlib import contextmanager
 from typing import Any, Dict, List, Tuple, Union
 
@@ -19,8 +18,6 @@ from ..util import (
     log_txt_as_img,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class DiffusionEngine(pl.LightningModule):
     def __init__(
@@ -76,7 +73,7 @@ class DiffusionEngine(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         self.scale_factor = scale_factor
         self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -97,13 +94,13 @@ class DiffusionEngine(pl.LightningModule):
             raise NotImplementedError
 
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.info(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")
 
     def _init_first_stage(self, config):
         model = instantiate_from_config(config).eval()
@@ -182,14 +179,14 @@ class DiffusionEngine(pl.LightningModule):
             self.model_ema.store(self.model.parameters())
             self.model_ema.copy_to(self.model)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         return get_obj_from_str(cfg["target"])(
@@ -205,7 +202,7 @@ class DiffusionEngine(pl.LightningModule):
         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
         if self.scheduler_config is not None:
             scheduler = instantiate_from_config(self.scheduler_config)
-            logger.debug("Setting up LambdaLR scheduler...")
+            print("Setting up LambdaLR scheduler...")
             scheduler = [
                 {
                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),

View File

@@ -1,4 +1,3 @@
-import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -9,10 +8,6 @@ from einops import rearrange, repeat
 from packaging import version
 from torch import nn
 
-logger = logging.getLogger(__name__)
-
-
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -41,9 +36,9 @@ else:
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    logger.warning(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
-        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    print(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )
 
 try:
@@ -53,7 +48,7 @@ try:
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from .diffusionmodules.util import checkpoint
@@ -294,7 +289,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        logger.info(
+        print(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )
@@ -398,21 +393,22 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            logger.warning(
+            print(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            logger.warning(
+            print(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                raise NotImplementedError(
-                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-                )
-            logger.info("Falling back to xformers efficient attention.")
-            attn_mode = "softmax-xformers"
+                assert (
+                    False
+                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+            else:
+                print("Falling back to xformers efficient attention.")
+                attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -441,7 +437,7 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            logger.info(f"{self.__class__.__name__} is using checkpointing")
+            print(f"{self.__class__.__name__} is using checkpointing")
 
     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -558,7 +554,7 @@ class SpatialTransformer(nn.Module):
         sdp_backend=None,
     ):
         super().__init__()
-        logger.debug(
+        print(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig
@@ -567,8 +563,8 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
             if depth != len(context_dim):
-                logger.warning(
-                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                print(
+                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.

View File

@@ -1,4 +1,3 @@
-import logging
 from typing import Any, Union
 
 import torch
@@ -11,9 +10,6 @@ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
 from ....util import default, instantiate_from_config
 
-logger = logging.getLogger(__name__)
-
 
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value
@@ -108,7 +104,7 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
         super().__init__()
         self.dims = dims
         if self.dims > 2:
-            logger.info(
+            print(
                 f"running with dims={dims}. This means that for perceptual loss calculation, "
                 f"the LPIPS loss will be applied to each frame independently. "
             )

View File

@@ -1,5 +1,4 @@
 # pytorch_diffusion + derived encoder decoder
-import logging
 import math
 from typing import Any, Callable, Optional
 
@@ -9,8 +8,6 @@ import torch.nn as nn
 from einops import rearrange
 from packaging import version
 
-logger = logging.getLogger(__name__)
-
 try:
     import xformers
     import xformers.ops
@@ -18,7 +15,7 @@ try:
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
@@ -291,14 +288,12 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
             f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
         )
         attn_type = "vanilla-xformers"
-    logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
         return AttnBlock(in_channels)
     elif attn_type == "vanilla-xformers":
-        logger.debug(
-            f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
-        )
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
     elif type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
@@ -638,8 +633,10 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        logger.debug(
-            f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
+        print(
+            "Working with z of shape {} = {} dimensions.".format(
+                self.z_shape, np.prod(self.z_shape)
+            )
         )
 
         make_attn_cls = self._make_attn()

View File

@@ -1,4 +1,3 @@
-import logging
 import math
 from abc import abstractmethod
 from functools import partial
@@ -22,8 +21,6 @@ from ...modules.diffusionmodules.util import (
 )
 from ...util import default, exists
 
-logger = logging.getLogger(__name__)
-
 
 # dummy replace
 def convert_module_to_f16(x):
@@ -180,13 +177,13 @@ class Downsample(nn.Module):
         self.dims = dims
         stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
         if use_conv:
-            logger.debug(
-                f"Building a Downsample layer with {dims} dims.\n"
+            print(f"Building a Downsample layer with {dims} dims.")
+            print(
                 f"  --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
                 f"kernel-size: 3, stride: {stride}, padding: {padding}"
             )
             if dims == 3:
-                logger.debug(f"  --> Downsampling third axis (time): {third_down}")
+                print(f"  --> Downsampling third axis (time): {third_down}")
             self.op = conv_nd(
                 dims,
                 self.channels,
@@ -273,7 +270,7 @@ class ResBlock(TimestepBlock):
             2 * self.out_channels if use_scale_shift_norm else self.out_channels
         )
         if self.skip_t_emb:
-            logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
+            print(f"Skipping timestep embedding in {self.__class__.__name__}")
             assert not self.use_scale_shift_norm
             self.emb_layers = None
             self.exchange_temb_dims = False
@@ -622,12 +619,12 @@ class UNetModel(nn.Module):
                     range(len(num_attention_blocks)),
                 )
             )
-            logger.warning(
+            print(
                 f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
-            )
+            )  # todo: convert to warning
 
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
@@ -636,7 +633,7 @@ class UNetModel(nn.Module):
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
         if use_fp16:
-            logger.warning("use_fp16 was dropped and has no effect anymore.")
+            print("WARNING: use_fp16 was dropped and has no effect anymore.")
             # self.dtype = th.float16 if use_fp16 else th.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
@@ -667,7 +664,7 @@ class UNetModel(nn.Module):
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
             elif self.num_classes == "continuous":
-                logger.debug("setting up linear c_adm embedding layer")
+                print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
             elif self.num_classes == "timestep":
                 self.label_emb = checkpoint_wrapper_fn(

View File

@@ -1,4 +1,3 @@
-import logging
 from contextlib import nullcontext
 from functools import partial
 from typing import Dict, List, Optional, Tuple, Union
@@ -33,8 +32,6 @@ from ...util import (
     instantiate_from_config,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class AbstractEmbModel(nn.Module):
     def __init__(self):
@@ -99,7 +96,7 @@ class GeneralConditioner(nn.Module):
                 for param in embedder.parameters():
                     param.requires_grad = False
                 embedder.eval()
-            logger.debug(
+            print(
                 f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                 f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
             )
@@ -730,7 +727,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
             )
         if tokens is not None:
             tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
-            logger.warning(
+            print(
                 f"You are running very experimental token-concat in {self.__class__.__name__}. "
                 f"Check what you are doing, and then remove this message."
             )
@@ -756,7 +753,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
             clip_version, device, max_length=clip_max_length
         )
         self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
-        logger.debug(
+        print(
             f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
            f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
        )
@@ -798,7 +795,7 @@ class SpatialRescaler(nn.Module):
         self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
         self.remap_output = out_channels is not None or remap_output
         if self.remap_output:
-            logger.debug(
+            print(
                 f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
             )
             self.channel_mapper = nn.Conv2d(

View File

@@ -1,6 +1,5 @@
 import functools
 import importlib
-import logging
 import os
 from functools import partial
 from inspect import isfunction
@@ -11,8 +10,6 @@ import torch
 from PIL import Image, ImageDraw, ImageFont
 from safetensors.torch import load_file as load_safetensors
 
-logger = logging.getLogger(__name__)
-
 
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
@@ -89,7 +86,7 @@ def log_txt_as_img(wh, xc, size=10):
         try:
             draw.text((0, 0), lines, fill="black", font=font)
         except UnicodeEncodeError:
-            logger.warning("Cant encode string %r for logging. Skipping.", lines)
+            print("Cant encode string for logging. Skipping.")
 
         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
         txts.append(txt)
@@ -164,7 +161,7 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
     return total_params
@@ -203,11 +200,11 @@ def append_dims(x, target_dims):
 def load_model_from_config(config, ckpt, verbose=True, freeze=True):
-    logger.info(f"Loading model from {ckpt}")
+    print(f"Loading model from {ckpt}")
     if ckpt.endswith("ckpt"):
         pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
-            logger.debug(f"Global Step: {pl_sd['global_step']}")
+            print(f"Global Step: {pl_sd['global_step']}")
         sd = pl_sd["state_dict"]
     elif ckpt.endswith("safetensors"):
         sd = load_safetensors(ckpt)
@@ -216,13 +213,14 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
     model = instantiate_from_config(config.model)
 
-    missing, unexpected = model.load_state_dict(sd, strict=False)
-    if verbose:
-        if missing:
-            logger.info("missing keys: %r", missing)
-        if unexpected:
-            logger.info("unexpected keys: %r", unexpected)
+    m, u = model.load_state_dict(sd, strict=False)
+
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
 
     if freeze:
         for param in model.parameters():