diff --git a/sgm/data/dataset.py b/sgm/data/dataset.py
index b726149..ffa9ab8 100644
--- a/sgm/data/dataset.py
+++ b/sgm/data/dataset.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional
 
 import torchdata.datapipes.iter
@@ -5,16 +6,17 @@ import webdataset as wds
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule
 
+logger = logging.getLogger(__name__)
+
 try:
     from sdata import create_dataset, create_dummy_dataset, create_loader
 except ImportError as e:
-    print("#" * 100)
-    print("Datasets not yet available")
-    print("to enable, we need to add stable-datasets as a submodule")
-    print("please use ``git submodule update --init --recursive``")
-    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
-    print("#" * 100)
-    exit(1)
+    raise NotImplementedError(
+        "Datasets not yet available. "
+        "To enable, we need to add stable-datasets as a submodule; "
+        "please use ``git submodule update --init --recursive`` "
+        "and do ``pip install -e stable-datasets/`` from the root of this repo"
+    ) from e
 
 
 class StableDataModuleFromConfig(LightningDataModule):
@@ -39,8 +41,8 @@ class StableDataModuleFromConfig(LightningDataModule):
                 "datapipeline" in self.val_config and "loader" in self.val_config
             ), "validation config requires the fields `datapipeline` and `loader`"
         else:
-            print(
-                "Warning: No Validation datapipeline defined, using that one from training"
+            logger.warning(
+                "No Validation datapipeline defined, using that one from training"
             )
             self.val_config = train
 
@@ -52,12 +54,10 @@ class StableDataModuleFromConfig(LightningDataModule):
 
         self.dummy = dummy
         if self.dummy:
-            print("#" * 100)
-            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
-            print("#" * 100)
+            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
 
     def setup(self, stage: str) -> None:
-        print("Preparing datasets")
+        logger.debug("Preparing datasets")
         if self.dummy:
             data_fn = create_dummy_dataset
         else:
diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py
index b2f4d38..33a0ac4 100644
--- a/sgm/lr_scheduler.py
+++ b/sgm/lr_scheduler.py
@@ -1,5 +1,9 @@
+import logging
+
 import numpy as np
 
+logger = logging.getLogger(__name__)
+
 
 class LambdaWarmUpCosineScheduler:
     """
@@ -24,9 +28,8 @@ class LambdaWarmUpCosineScheduler:
         self.verbosity_interval = verbosity_interval
 
     def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         if n < self.lr_warm_up_steps:
             lr = (
                 self.lr_max - self.lr_start
@@ -83,12 +86,11 @@ class LambdaWarmUpCosineScheduler2:
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(
-                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                    f"current cycle {cycle}"
-                )
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(
+                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                f"current cycle {cycle}"
+            )
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                 cycle
             ]
@@ -114,12 +116,11 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0:
-            if n % self.verbosity_interval == 0:
-                print(
-                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                    f"current cycle {cycle}"
-                )
+        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
+            logger.info(
+                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                f"current cycle {cycle}"
+            )
 
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py
index 78fb551..c7114af 100644
--- a/sgm/models/autoencoder.py
+++ b/sgm/models/autoencoder.py
@@ -1,3 +1,4 @@
+import logging
 import re
 from abc import abstractmethod
 from contextlib import contextmanager
@@ -14,6 +15,8 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
 from ..modules.ema import LitEma
 from ..util import default, get_obj_from_str, instantiate_from_config
 
+logger = logging.getLogger(__name__)
+
 
 class AbstractAutoencoder(pl.LightningModule):
     """
@@ -38,7 +41,7 @@ class AbstractAutoencoder(pl.LightningModule):
 
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -60,16 +63,16 @@ class AbstractAutoencoder(pl.LightningModule):
         for k in keys:
             for ik in ignore_keys:
                 if re.match(ik, k):
-                    print("Deleting key {} from state_dict.".format(k))
+                    logger.debug(f"Deleting key {k} from state_dict.")
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        print(
+        logger.debug(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            print(f"Missing Keys: {missing}")
+            logger.info(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            print(f"Unexpected Keys: {unexpected}")
+            logger.info(f"Unexpected Keys: {unexpected}")
 
     @abstractmethod
     def get_input(self, batch) -> Any:
@@ -86,14 +89,14 @@ class AbstractAutoencoder(pl.LightningModule):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logger.info(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logger.info(f"{context}: Restored training weights")
 
     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -104,7 +107,7 @@ class AbstractAutoencoder(pl.LightningModule):
         raise NotImplementedError("decode()-method of abstract base class called")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        print(f"loading >>> {cfg['target']} <<< optimizer from config")
+        logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )
diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py
index e1f1397..8b7cc17 100644
--- a/sgm/models/diffusion.py
+++ b/sgm/models/diffusion.py
@@ -1,3 +1,4 @@
+import logging
 from contextlib import contextmanager
 from typing import Any, Dict, List, Tuple, Union
 
@@ -18,6 +19,8 @@ from ..util import (
     log_txt_as_img,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class DiffusionEngine(pl.LightningModule):
     def __init__(
@@ -73,7 +76,7 @@ class DiffusionEngine(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
-            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         self.scale_factor = scale_factor
         self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -94,13 +97,13 @@ class DiffusionEngine(pl.LightningModule):
             raise NotImplementedError
 
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        print(
+        logger.info(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            print(f"Missing Keys: {missing}")
+            logger.info(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            print(f"Unexpected Keys: {unexpected}")
+            logger.info(f"Unexpected Keys: {unexpected}")
 
     def _init_first_stage(self, config):
         model = instantiate_from_config(config).eval()
@@ -179,14 +182,14 @@ class DiffusionEngine(pl.LightningModule):
             self.model_ema.store(self.model.parameters())
             self.model_ema.copy_to(self.model)
             if context is not None:
-                print(f"{context}: Switched to EMA weights")
+                logger.info(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.model.parameters())
                 if context is not None:
-                    print(f"{context}: Restored training weights")
+                    logger.info(f"{context}: Restored training weights")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         return get_obj_from_str(cfg["target"])(
@@ -202,7 +205,7 @@ class DiffusionEngine(pl.LightningModule):
         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
         if self.scheduler_config is not None:
             scheduler = instantiate_from_config(self.scheduler_config)
-            print("Setting up LambdaLR scheduler...")
+            logger.debug("Setting up LambdaLR scheduler...")
             scheduler = [
                 {
                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
index a17edda..38af6e6 100644
--- a/sgm/modules/attention.py
+++ b/sgm/modules/attention.py
@@ -1,3 +1,4 @@
+import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -8,6 +9,10 @@ from einops import rearrange, repeat
 from packaging import version
 from torch import nn
 
+
+logger = logging.getLogger(__name__)
+
+
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -36,9 +41,9 @@ else:
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    print(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
-        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    logger.warning(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
+        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )
 
 try:
@@ -48,7 +53,7 @@ try:
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    print("no module 'xformers'. Processing without...")
+    logger.debug("no module 'xformers'. Processing without...")
 
 from .diffusionmodules.util import checkpoint
 
@@ -289,7 +294,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        print(
+        logger.info(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )
@@ -393,22 +398,21 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            print(
+            logger.warning(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            print(
+            logger.warning(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                assert (
-                    False
-                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-            else:
-                print("Falling back to xformers efficient attention.")
-                attn_mode = "softmax-xformers"
+                raise NotImplementedError(
+                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+                )
+            logger.info("Falling back to xformers efficient attention.")
+            attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -437,7 +441,7 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            print(f"{self.__class__.__name__} is using checkpointing")
+            logger.info(f"{self.__class__.__name__} is using checkpointing")
 
     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -554,7 +558,7 @@ class SpatialTransformer(nn.Module):
         sdp_backend=None,
     ):
         super().__init__()
-        print(
+        logger.debug(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig
@@ -563,8 +567,8 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
             if depth != len(context_dim):
-                print(
-                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                logger.warning(
+                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.
diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py
index 6a3b54f..c4964f7 100644
--- a/sgm/modules/autoencoding/losses/__init__.py
+++ b/sgm/modules/autoencoding/losses/__init__.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Union
 
 import torch
@@ -10,6 +11,9 @@ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
 
 from ....util import default, instantiate_from_config
 
+logger = logging.getLogger(__name__)
+
+
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value
@@ -104,7 +108,7 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
         super().__init__()
         self.dims = dims
         if self.dims > 2:
-            print(
+            logger.info(
                 f"running with dims={dims}. This means that for perceptual loss calculation, "
                 f"the LPIPS loss will be applied to each frame independently. "
             )
diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py
index 26efd07..2b24deb 100644
--- a/sgm/modules/diffusionmodules/model.py
+++ b/sgm/modules/diffusionmodules/model.py
@@ -1,4 +1,5 @@
 # pytorch_diffusion + derived encoder decoder
+import logging
 import math
 from typing import Any, Callable, Optional
 
@@ -8,6 +9,8 @@ import torch.nn as nn
 from einops import rearrange
 from packaging import version
 
+logger = logging.getLogger(__name__)
+
 try:
     import xformers
     import xformers.ops
@@ -15,7 +18,7 @@ try:
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    print("no module 'xformers'. Processing without...")
+    logger.debug("no module 'xformers'. Processing without...")
 
 from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
 
@@ -288,12 +291,14 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
             f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
         )
         attn_type = "vanilla-xformers"
-    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
        assert attn_kwargs is None
         return AttnBlock(in_channels)
     elif attn_type == "vanilla-xformers":
-        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
+        logger.debug(
+            f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
+        )
         return MemoryEfficientAttnBlock(in_channels)
     elif type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
@@ -633,10 +638,8 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        print(
-            "Working with z of shape {} = {} dimensions.".format(
-                self.z_shape, np.prod(self.z_shape)
-            )
+        logger.debug(
+            f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
         )
 
         make_attn_cls = self._make_attn()
diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py
index e19b83f..1c87442 100644
--- a/sgm/modules/diffusionmodules/openaimodel.py
+++ b/sgm/modules/diffusionmodules/openaimodel.py
@@ -1,3 +1,4 @@
+import logging
 import math
 from abc import abstractmethod
 from functools import partial
@@ -21,6 +22,8 @@ from ...modules.diffusionmodules.util import (
 )
 from ...util import default, exists
 
+logger = logging.getLogger(__name__)
+
 
 # dummy replace
 def convert_module_to_f16(x):
@@ -177,13 +180,13 @@ class Downsample(nn.Module):
         self.dims = dims
         stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
         if use_conv:
-            print(f"Building a Downsample layer with {dims} dims.")
-            print(
+            logger.debug(
+                f"Building a Downsample layer with {dims} dims.\n"
                 f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
                 f"kernel-size: 3, stride: {stride}, padding: {padding}"
             )
             if dims == 3:
-                print(f" --> Downsampling third axis (time): {third_down}")
+                logger.debug(f" --> Downsampling third axis (time): {third_down}")
         self.op = conv_nd(
             dims,
             self.channels,
@@ -270,7 +273,7 @@ class ResBlock(TimestepBlock):
             2 * self.out_channels if use_scale_shift_norm else self.out_channels
         )
         if self.skip_t_emb:
-            print(f"Skipping timestep embedding in {self.__class__.__name__}")
+            logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
             assert not self.use_scale_shift_norm
             self.emb_layers = None
             self.exchange_temb_dims = False
@@ -619,12 +622,12 @@ class UNetModel(nn.Module):
                     range(len(num_attention_blocks)),
                 )
             )
-            print(
+            logger.warning(
                 f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
-            )  # todo: convert to warning
+            )
 
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
@@ -633,7 +636,7 @@ class UNetModel(nn.Module):
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
         if use_fp16:
-            print("WARNING: use_fp16 was dropped and has no effect anymore.")
+            logger.warning("use_fp16 was dropped and has no effect anymore.")
         # self.dtype = th.float16 if use_fp16 else th.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
@@ -664,7 +667,7 @@ class UNetModel(nn.Module):
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
             elif self.num_classes == "continuous":
-                print("setting up linear c_adm embedding layer")
+                logger.debug("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
             elif self.num_classes == "timestep":
                 self.label_emb = checkpoint_wrapper_fn(
diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py
index ed3f2d2..afb4bb6 100644
--- a/sgm/modules/encoders/modules.py
+++ b/sgm/modules/encoders/modules.py
@@ -1,3 +1,4 @@
+import logging
 from contextlib import nullcontext
 from functools import partial
 from typing import Dict, List, Optional, Tuple, Union
@@ -32,6 +33,8 @@ from ...util import (
     instantiate_from_config,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class AbstractEmbModel(nn.Module):
     def __init__(self):
@@ -96,7 +99,7 @@ class GeneralConditioner(nn.Module):
                 for param in embedder.parameters():
                     param.requires_grad = False
                 embedder.eval()
-            print(
+            logger.debug(
                 f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                 f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
             )
@@ -727,7 +730,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
             )
         if tokens is not None:
             tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
-            print(
+            logger.warning(
                 f"You are running very experimental token-concat in {self.__class__.__name__}. "
                 f"Check what you are doing, and then remove this message."
             )
@@ -753,7 +756,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
             clip_version, device, max_length=clip_max_length
         )
         self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
-        print(
+        logger.debug(
             f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
             f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
         )
@@ -795,7 +798,7 @@ class SpatialRescaler(nn.Module):
         self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
         self.remap_output = out_channels is not None or remap_output
         if self.remap_output:
-            print(
+            logger.debug(
                 f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
             )
             self.channel_mapper = nn.Conv2d(
diff --git a/sgm/util.py b/sgm/util.py
index c5e68f4..e23616d 100644
--- a/sgm/util.py
+++ b/sgm/util.py
@@ -1,5 +1,6 @@
 import functools
 import importlib
+import logging
 import os
 from functools import partial
 from inspect import isfunction
@@ -10,6 +11,8 @@ import torch
 from PIL import Image, ImageDraw, ImageFont
 from safetensors.torch import load_file as load_safetensors
 
+logger = logging.getLogger(__name__)
+
 
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
@@ -86,7 +89,7 @@ def log_txt_as_img(wh, xc, size=10):
         try:
             draw.text((0, 0), lines, fill="black", font=font)
         except UnicodeEncodeError:
-            print("Cant encode string for logging. Skipping.")
+            logger.warning("Cant encode string %r for logging. Skipping.", lines)
 
         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
         txts.append(txt)
@@ -161,7 +164,7 @@ def mean_flat(tensor):
 
 
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+        logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
     return total_params
@@ -200,11 +203,11 @@ def append_dims(x, target_dims):
 
 
 def load_model_from_config(config, ckpt, verbose=True, freeze=True):
-    print(f"Loading model from {ckpt}")
+    logger.info(f"Loading model from {ckpt}")
     if ckpt.endswith("ckpt"):
         pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
+            logger.debug(f"Global Step: {pl_sd['global_step']}")
         sd = pl_sd["state_dict"]
     elif ckpt.endswith("safetensors"):
         sd = load_safetensors(ckpt)
@@ -213,14 +216,13 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
 
     model = instantiate_from_config(config.model)
 
-    m, u = model.load_state_dict(sd, strict=False)
+    missing, unexpected = model.load_state_dict(sd, strict=False)
 
-    if len(m) > 0 and verbose:
-        print("missing keys:")
-        print(m)
-    if len(u) > 0 and verbose:
-        print("unexpected keys:")
-        print(u)
+    if verbose:
+        if missing:
+            logger.info("missing keys: %r", missing)
+        if unexpected:
+            logger.info("unexpected keys: %r", unexpected)
 
     if freeze:
         for param in model.parameters():