diff --git a/sgm/data/dataset.py b/sgm/data/dataset.py index ffa9ab8..b726149 100644 --- a/sgm/data/dataset.py +++ b/sgm/data/dataset.py @@ -1,4 +1,3 @@ -import logging from typing import Optional import torchdata.datapipes.iter @@ -6,17 +5,16 @@ import webdataset as wds from omegaconf import DictConfig from pytorch_lightning import LightningDataModule -logger = logging.getLogger(__name__) - try: from sdata import create_dataset, create_dummy_dataset, create_loader except ImportError as e: - raise NotImplementedError( - "Datasets not yet available. " - "To enable, we need to add stable-datasets as a submodule; " - "please use ``git submodule update --init --recursive`` " - "and do ``pip install -e stable-datasets/`` from the root of this repo" - ) from e + print("#" * 100) + print("Datasets not yet available") + print("to enable, we need to add stable-datasets as a submodule") + print("please use ``git submodule update --init --recursive``") + print("and do ``pip install -e stable-datasets/`` from the root of this repo") + print("#" * 100) + exit(1) class StableDataModuleFromConfig(LightningDataModule): @@ -41,8 +39,8 @@ class StableDataModuleFromConfig(LightningDataModule): "datapipeline" in self.val_config and "loader" in self.val_config ), "validation config requires the fields `datapipeline` and `loader`" else: - logger.warning( - "No Validation datapipeline defined, using that one from training" + print( + "Warning: No Validation datapipeline defined, using that one from training" ) self.val_config = train @@ -54,10 +52,12 @@ class StableDataModuleFromConfig(LightningDataModule): self.dummy = dummy if self.dummy: - logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") + print("#" * 100) + print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)") + print("#" * 100) def setup(self, stage: str) -> None: - logger.debug("Preparing datasets") + print("Preparing datasets") if self.dummy: data_fn = create_dummy_dataset else: diff --git a/sgm/lr_scheduler.py b/sgm/lr_scheduler.py index 33a0ac4..b2f4d38 100644 --- a/sgm/lr_scheduler.py +++ b/sgm/lr_scheduler.py @@ -1,9 +1,5 @@ -import logging - import numpy as np -logger = logging.getLogger(__name__) - class LambdaWarmUpCosineScheduler: """ @@ -28,8 +24,9 @@ class LambdaWarmUpCosineScheduler: self.verbosity_interval = verbosity_interval def schedule(self, n, **kwargs): - if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: - logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") if n < self.lr_warm_up_steps: lr = ( self.lr_max - self.lr_start @@ -86,11 +83,12 @@ class LambdaWarmUpCosineScheduler2: def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: - logger.info( - f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}" - ) + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ cycle @@ -116,11 +114,12 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): def schedule(self, n, **kwargs): cycle = self.find_in_interval(n) n = n - self.cum_cycles[cycle] - if self.verbosity_interval > 0 and n % self.verbosity_interval == 0: - logger.info( - f"current step: {n}, recent lr-multiplier: {self.last_f}, " - f"current cycle {cycle}" - ) + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print( + f"current step: {n}, recent lr-multiplier: {self.last_f}, " + f"current cycle {cycle}" + ) if n < self.lr_warm_up_steps[cycle]: f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[ diff --git a/sgm/models/autoencoder.py b/sgm/models/autoencoder.py index c7114af..78fb551 100644 --- a/sgm/models/autoencoder.py +++ b/sgm/models/autoencoder.py @@ -1,4 +1,3 @@ -import logging import re from abc import abstractmethod from contextlib import contextmanager @@ -15,8 +14,6 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution from ..modules.ema import LitEma from ..util import default, get_obj_from_str, instantiate_from_config -logger = logging.getLogger(__name__) - class AbstractAutoencoder(pl.LightningModule): """ @@ -41,7 +38,7 @@ class AbstractAutoencoder(pl.LightningModule): if self.use_ema: self.model_ema = LitEma(self, decay=ema_decay) - logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") if ckpt_path is not None: self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) @@ -63,16 +60,16 @@ class AbstractAutoencoder(pl.LightningModule): for k in keys: for ik in ignore_keys: if re.match(ik, k): - logger.debug(f"Deleting key {k} from state_dict.") + print("Deleting key {} from state_dict.".format(k)) del sd[k] missing, unexpected = self.load_state_dict(sd, strict=False) - logger.debug( + print( f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" ) if len(missing) > 0: - logger.info(f"Missing Keys: {missing}") + print(f"Missing Keys: {missing}") if len(unexpected) > 0: - logger.info(f"Unexpected Keys: {unexpected}") + print(f"Unexpected Keys: {unexpected}") @abstractmethod def get_input(self, batch) -> Any: @@ -89,14 +86,14 @@ class AbstractAutoencoder(pl.LightningModule): self.model_ema.store(self.parameters()) self.model_ema.copy_to(self) if context is not None: - logger.info(f"{context}: Switched to EMA weights") + print(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.parameters()) if context is not None: - logger.info(f"{context}: Restored training weights") + print(f"{context}: Restored training weights") @abstractmethod def encode(self, *args, **kwargs) -> torch.Tensor: @@ -107,7 +104,7 @@ class AbstractAutoencoder(pl.LightningModule): raise NotImplementedError("decode()-method of abstract base class called") def instantiate_optimizer_from_config(self, params, lr, cfg): - logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config") + print(f"loading >>> {cfg['target']} <<< optimizer from config") return get_obj_from_str(cfg["target"])( params, lr=lr, **cfg.get("params", dict()) ) diff --git a/sgm/models/diffusion.py b/sgm/models/diffusion.py index 3ffb146..e1f1397 100644 --- a/sgm/models/diffusion.py +++ b/sgm/models/diffusion.py @@ -1,4 +1,3 @@ -import logging from contextlib import contextmanager from typing import Any, Dict, List, Tuple, Union @@ -19,8 +18,6 @@ from ..util import ( log_txt_as_img, ) -logger = logging.getLogger(__name__) - class DiffusionEngine(pl.LightningModule): def __init__( @@ -76,7 +73,7 @@ class DiffusionEngine(pl.LightningModule): self.use_ema = use_ema if self.use_ema: self.model_ema = LitEma(self.model, decay=ema_decay_rate) - logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") self.scale_factor = scale_factor self.disable_first_stage_autocast = disable_first_stage_autocast @@ -97,13 +94,13 @@ class DiffusionEngine(pl.LightningModule): raise NotImplementedError missing, unexpected = self.load_state_dict(sd, strict=False) - logger.info( + print( f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys" ) if len(missing) > 0: - logger.info(f"Missing Keys: {missing}") + print(f"Missing Keys: {missing}") if len(unexpected) > 0: - logger.info(f"Unexpected Keys: {unexpected}") + print(f"Unexpected Keys: {unexpected}") def _init_first_stage(self, config): model = instantiate_from_config(config).eval() @@ -182,14 +179,14 @@ class DiffusionEngine(pl.LightningModule): self.model_ema.store(self.model.parameters()) self.model_ema.copy_to(self.model) if context is not None: - logger.info(f"{context}: Switched to EMA weights") + print(f"{context}: Switched to EMA weights") try: yield None finally: if self.use_ema: self.model_ema.restore(self.model.parameters()) if context is not None: - logger.info(f"{context}: Restored training weights") + print(f"{context}: Restored training weights") def instantiate_optimizer_from_config(self, params, lr, cfg): return get_obj_from_str(cfg["target"])( @@ -205,7 +202,7 @@ class DiffusionEngine(pl.LightningModule): opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config) if self.scheduler_config is not None: scheduler = instantiate_from_config(self.scheduler_config) - logger.debug("Setting up LambdaLR scheduler...") + print("Setting up LambdaLR scheduler...") scheduler = [ { "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py index 1de2918..f813be2 100644 --- a/sgm/modules/attention.py +++ b/sgm/modules/attention.py @@ -1,4 +1,3 @@ -import logging import math from inspect import isfunction from typing import Any, Optional @@ -9,10 +8,6 @@ from einops import rearrange, repeat from packaging import version from torch import nn - -logger = logging.getLogger(__name__) - - if version.parse(torch.__version__) >= version.parse("2.0.0"): SDP_IS_AVAILABLE = True from torch.backends.cuda import SDPBackend, sdp_kernel @@ -41,9 +36,9 @@ else: SDP_IS_AVAILABLE = False sdp_kernel = nullcontext BACKEND_MAP = {} - logger.warning( - f"No SDP backend available, likely because you are running in pytorch versions < 2.0. " - f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading." + print( + f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, " + f"you are using PyTorch {torch.__version__}. You might want to consider upgrading." ) try: @@ -53,7 +48,7 @@ try: XFORMERS_IS_AVAILABLE = True except: XFORMERS_IS_AVAILABLE = False - logger.debug("no module 'xformers'. Processing without...") + print("no module 'xformers'. Processing without...") from .diffusionmodules.util import checkpoint @@ -294,7 +289,7 @@ class MemoryEfficientCrossAttention(nn.Module): self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs ): super().__init__() - logger.info( + print( f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using " f"{heads} heads with a dimension of {dim_head}." ) @@ -398,21 +393,22 @@ class BasicTransformerBlock(nn.Module): super().__init__() assert attn_mode in self.ATTENTION_MODES if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE: - logger.warning( + print( f"Attention mode '{attn_mode}' is not available. Falling back to native attention. " f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}" ) attn_mode = "softmax" elif attn_mode == "softmax" and not SDP_IS_AVAILABLE: - logger.warning( + print( "We do not support vanilla attention anymore, as it is too expensive. Sorry." ) if not XFORMERS_IS_AVAILABLE: - raise NotImplementedError( - "Please install xformers via e.g. 'pip install xformers==0.0.16'" - ) - logger.info("Falling back to xformers efficient attention.") - attn_mode = "softmax-xformers" + assert ( + False + ), "Please install xformers via e.g. 'pip install xformers==0.0.16'" + else: + print("Falling back to xformers efficient attention.") + attn_mode = "softmax-xformers" attn_cls = self.ATTENTION_MODES[attn_mode] if version.parse(torch.__version__) >= version.parse("2.0.0"): assert sdp_backend is None or isinstance(sdp_backend, SDPBackend) @@ -441,7 +437,7 @@ class BasicTransformerBlock(nn.Module): self.norm3 = nn.LayerNorm(dim) self.checkpoint = checkpoint if self.checkpoint: - logger.info(f"{self.__class__.__name__} is using checkpointing") + print(f"{self.__class__.__name__} is using checkpointing") def forward( self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0 @@ -558,7 +554,7 @@ class SpatialTransformer(nn.Module): sdp_backend=None, ): super().__init__() - logger.debug( + print( f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads" ) from omegaconf import ListConfig @@ -567,8 +563,8 @@ class SpatialTransformer(nn.Module): context_dim = [context_dim] if exists(context_dim) and isinstance(context_dim, list): if depth != len(context_dim): - logger.warning( - f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, " + print( + f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, " f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now." ) # depth does not match context dims. diff --git a/sgm/modules/autoencoding/losses/__init__.py b/sgm/modules/autoencoding/losses/__init__.py index c4964f7..6a3b54f 100644 --- a/sgm/modules/autoencoding/losses/__init__.py +++ b/sgm/modules/autoencoding/losses/__init__.py @@ -1,4 +1,3 @@ -import logging from typing import Any, Union import torch @@ -11,9 +10,6 @@ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss from ....util import default, instantiate_from_config -logger = logging.getLogger(__name__) - - def adopt_weight(weight, global_step, threshold=0, value=0.0): if global_step < threshold: weight = value @@ -108,7 +104,7 @@ class GeneralLPIPSWithDiscriminator(nn.Module): super().__init__() self.dims = dims if self.dims > 2: - logger.info( + print( f"running with dims={dims}. This means that for perceptual loss calculation, " f"the LPIPS loss will be applied to each frame independently. " ) diff --git a/sgm/modules/diffusionmodules/model.py b/sgm/modules/diffusionmodules/model.py index 2b24deb..26efd07 100644 --- a/sgm/modules/diffusionmodules/model.py +++ b/sgm/modules/diffusionmodules/model.py @@ -1,5 +1,4 @@ # pytorch_diffusion + derived encoder decoder -import logging import math from typing import Any, Callable, Optional @@ -9,8 +8,6 @@ import torch.nn as nn from einops import rearrange from packaging import version -logger = logging.getLogger(__name__) - try: import xformers import xformers.ops @@ -18,7 +15,7 @@ try: XFORMERS_IS_AVAILABLE = True except: XFORMERS_IS_AVAILABLE = False - logger.debug("no module 'xformers'. Processing without...") + print("no module 'xformers'. Processing without...") from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention @@ -291,14 +288,12 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'" ) attn_type = "vanilla-xformers" - logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels") + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": assert attn_kwargs is None return AttnBlock(in_channels) elif attn_type == "vanilla-xformers": - logger.debug( - f"building MemoryEfficientAttnBlock with {in_channels} in_channels..." - ) + print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...") return MemoryEfficientAttnBlock(in_channels) elif type == "memory-efficient-cross-attn": attn_kwargs["query_dim"] = in_channels @@ -638,8 +633,10 @@ class Decoder(nn.Module): block_in = ch * ch_mult[self.num_resolutions - 1] curr_res = resolution // 2 ** (self.num_resolutions - 1) self.z_shape = (1, z_channels, curr_res, curr_res) - logger.debug( - f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions." + print( + "Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape) + ) ) make_attn_cls = self._make_attn() diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py index 1c87442..e19b83f 100644 --- a/sgm/modules/diffusionmodules/openaimodel.py +++ b/sgm/modules/diffusionmodules/openaimodel.py @@ -1,4 +1,3 @@ -import logging import math from abc import abstractmethod from functools import partial @@ -22,8 +21,6 @@ from ...modules.diffusionmodules.util import ( ) from ...util import default, exists -logger = logging.getLogger(__name__) - # dummy replace def convert_module_to_f16(x): @@ -180,13 +177,13 @@ class Downsample(nn.Module): self.dims = dims stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2)) if use_conv: - logger.debug( - f"Building a Downsample layer with {dims} dims.\n" + print(f"Building a Downsample layer with {dims} dims.") + print( f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, " f"kernel-size: 3, stride: {stride}, padding: {padding}" ) if dims == 3: - logger.debug(f" --> Downsampling third axis (time): {third_down}") + print(f" --> Downsampling third axis (time): {third_down}") self.op = conv_nd( dims, self.channels, @@ -273,7 +270,7 @@ class ResBlock(TimestepBlock): 2 * self.out_channels if use_scale_shift_norm else self.out_channels ) if self.skip_t_emb: - logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}") + print(f"Skipping timestep embedding in {self.__class__.__name__}") assert not self.use_scale_shift_norm self.emb_layers = None self.exchange_temb_dims = False @@ -622,12 +619,12 @@ class UNetModel(nn.Module): range(len(num_attention_blocks)), ) ) - logger.warning( + print( f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " f"This option has LESS priority than attention_resolutions {attention_resolutions}, " f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " f"attention will still not be set." - ) + ) # todo: convert to warning self.attention_resolutions = attention_resolutions self.dropout = dropout @@ -636,7 +633,7 @@ class UNetModel(nn.Module): self.num_classes = num_classes self.use_checkpoint = use_checkpoint if use_fp16: - logger.warning("use_fp16 was dropped and has no effect anymore.") + print("WARNING: use_fp16 was dropped and has no effect anymore.") # self.dtype = th.float16 if use_fp16 else th.float32 self.num_heads = num_heads self.num_head_channels = num_head_channels @@ -667,7 +664,7 @@ class UNetModel(nn.Module): if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim) elif self.num_classes == "continuous": - logger.debug("setting up linear c_adm embedding layer") + print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "timestep": self.label_emb = checkpoint_wrapper_fn( diff --git a/sgm/modules/encoders/modules.py b/sgm/modules/encoders/modules.py index afb4bb6..ed3f2d2 100644 --- a/sgm/modules/encoders/modules.py +++ b/sgm/modules/encoders/modules.py @@ -1,4 +1,3 @@ -import logging from contextlib import nullcontext from functools import partial from typing import Dict, List, Optional, Tuple, Union @@ -33,8 +32,6 @@ from ...util import ( instantiate_from_config, ) -logger = logging.getLogger(__name__) - class AbstractEmbModel(nn.Module): def __init__(self): @@ -99,7 +96,7 @@ class GeneralConditioner(nn.Module): for param in embedder.parameters(): param.requires_grad = False embedder.eval() - logger.debug( + print( f"Initialized embedder #{n}: {embedder.__class__.__name__} " f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}" ) @@ -730,7 +727,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel): ) if tokens is not None: tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops) - logger.warning( + print( f"You are running very experimental token-concat in {self.__class__.__name__}. " f"Check what you are doing, and then remove this message." ) @@ -756,7 +753,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel): clip_version, device, max_length=clip_max_length ) self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) - logger.debug( + print( f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, " f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params." ) @@ -798,7 +795,7 @@ class SpatialRescaler(nn.Module): self.interpolator = partial(torch.nn.functional.interpolate, mode=method) self.remap_output = out_channels is not None or remap_output if self.remap_output: - logger.debug( + print( f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing." ) self.channel_mapper = nn.Conv2d( diff --git a/sgm/util.py b/sgm/util.py index e23616d..c5e68f4 100644 --- a/sgm/util.py +++ b/sgm/util.py @@ -1,6 +1,5 @@ import functools import importlib -import logging import os from functools import partial from inspect import isfunction @@ -11,8 +10,6 @@ import torch from PIL import Image, ImageDraw, ImageFont from safetensors.torch import load_file as load_safetensors -logger = logging.getLogger(__name__) - def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode @@ -89,7 +86,7 @@ def log_txt_as_img(wh, xc, size=10): try: draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: - logger.warning("Cant encode string %r for logging. Skipping.", lines) + print("Cant encode string for logging. Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) @@ -164,7 +161,7 @@ def mean_flat(tensor): def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: - logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") return total_params @@ -203,11 +200,11 @@ def append_dims(x, target_dims): def load_model_from_config(config, ckpt, verbose=True, freeze=True): - logger.info(f"Loading model from {ckpt}") + print(f"Loading model from {ckpt}") if ckpt.endswith("ckpt"): pl_sd = torch.load(ckpt, map_location="cpu") if "global_step" in pl_sd: - logger.debug(f"Global Step: {pl_sd['global_step']}") + print(f"Global Step: {pl_sd['global_step']}") sd = pl_sd["state_dict"] elif ckpt.endswith("safetensors"): sd = load_safetensors(ckpt) @@ -216,13 +213,14 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True): model = instantiate_from_config(config.model) - missing, unexpected = model.load_state_dict(sd, strict=False) + m, u = model.load_state_dict(sd, strict=False) - if verbose: - if missing: - logger.info("missing keys: %r", missing) - if unexpected: - logger.info("unexpected keys: %r", unexpected) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) if freeze: for param in model.parameters():