mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-19 14:24:21 +01:00
This reverts commit 6f6d3f8716.
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torchdata.datapipes.iter
|
||||
@@ -6,17 +5,16 @@ import webdataset as wds
|
||||
from omegaconf import DictConfig
|
||||
from pytorch_lightning import LightningDataModule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from sdata import create_dataset, create_dummy_dataset, create_loader
|
||||
except ImportError as e:
|
||||
raise NotImplementedError(
|
||||
"Datasets not yet available. "
|
||||
"To enable, we need to add stable-datasets as a submodule; "
|
||||
"please use ``git submodule update --init --recursive`` "
|
||||
"and do ``pip install -e stable-datasets/`` from the root of this repo"
|
||||
) from e
|
||||
print("#" * 100)
|
||||
print("Datasets not yet available")
|
||||
print("to enable, we need to add stable-datasets as a submodule")
|
||||
print("please use ``git submodule update --init --recursive``")
|
||||
print("and do ``pip install -e stable-datasets/`` from the root of this repo")
|
||||
print("#" * 100)
|
||||
exit(1)
|
||||
|
||||
|
||||
class StableDataModuleFromConfig(LightningDataModule):
|
||||
@@ -41,8 +39,8 @@ class StableDataModuleFromConfig(LightningDataModule):
|
||||
"datapipeline" in self.val_config and "loader" in self.val_config
|
||||
), "validation config requires the fields `datapipeline` and `loader`"
|
||||
else:
|
||||
logger.warning(
|
||||
"No Validation datapipeline defined, using that one from training"
|
||||
print(
|
||||
"Warning: No Validation datapipeline defined, using that one from training"
|
||||
)
|
||||
self.val_config = train
|
||||
|
||||
@@ -54,10 +52,12 @@ class StableDataModuleFromConfig(LightningDataModule):
|
||||
|
||||
self.dummy = dummy
|
||||
if self.dummy:
|
||||
logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
|
||||
print("#" * 100)
|
||||
print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
|
||||
print("#" * 100)
|
||||
|
||||
def setup(self, stage: str) -> None:
|
||||
logger.debug("Preparing datasets")
|
||||
print("Preparing datasets")
|
||||
if self.dummy:
|
||||
data_fn = create_dummy_dataset
|
||||
else:
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LambdaWarmUpCosineScheduler:
|
||||
"""
|
||||
@@ -28,8 +24,9 @@ class LambdaWarmUpCosineScheduler:
|
||||
self.verbosity_interval = verbosity_interval
|
||||
|
||||
def schedule(self, n, **kwargs):
|
||||
if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
|
||||
logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
|
||||
if n < self.lr_warm_up_steps:
|
||||
lr = (
|
||||
self.lr_max - self.lr_start
|
||||
@@ -86,11 +83,12 @@ class LambdaWarmUpCosineScheduler2:
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
|
||||
logger.info(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||
cycle
|
||||
@@ -116,11 +114,12 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
|
||||
def schedule(self, n, **kwargs):
|
||||
cycle = self.find_in_interval(n)
|
||||
n = n - self.cum_cycles[cycle]
|
||||
if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
|
||||
logger.info(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
if self.verbosity_interval > 0:
|
||||
if n % self.verbosity_interval == 0:
|
||||
print(
|
||||
f"current step: {n}, recent lr-multiplier: {self.last_f}, "
|
||||
f"current cycle {cycle}"
|
||||
)
|
||||
|
||||
if n < self.lr_warm_up_steps[cycle]:
|
||||
f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
@@ -15,8 +14,6 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
from ..modules.ema import LitEma
|
||||
from ..util import default, get_obj_from_str, instantiate_from_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractAutoencoder(pl.LightningModule):
|
||||
"""
|
||||
@@ -41,7 +38,7 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
@@ -63,16 +60,16 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if re.match(ik, k):
|
||||
logger.debug(f"Deleting key {k} from state_dict.")
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
logger.debug(
|
||||
print(
|
||||
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
||||
)
|
||||
if len(missing) > 0:
|
||||
logger.info(f"Missing Keys: {missing}")
|
||||
print(f"Missing Keys: {missing}")
|
||||
if len(unexpected) > 0:
|
||||
logger.info(f"Unexpected Keys: {unexpected}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
@abstractmethod
|
||||
def get_input(self, batch) -> Any:
|
||||
@@ -89,14 +86,14 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
logger.info(f"{context}: Switched to EMA weights")
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
logger.info(f"{context}: Restored training weights")
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||
@@ -107,7 +104,7 @@ class AbstractAutoencoder(pl.LightningModule):
|
||||
raise NotImplementedError("decode()-method of abstract base class called")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
print(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
@@ -19,8 +18,6 @@ from ..util import (
|
||||
log_txt_as_img,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DiffusionEngine(pl.LightningModule):
|
||||
def __init__(
|
||||
@@ -76,7 +73,7 @@ class DiffusionEngine(pl.LightningModule):
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self.model, decay=ema_decay_rate)
|
||||
logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
self.scale_factor = scale_factor
|
||||
self.disable_first_stage_autocast = disable_first_stage_autocast
|
||||
@@ -97,13 +94,13 @@ class DiffusionEngine(pl.LightningModule):
|
||||
raise NotImplementedError
|
||||
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
logger.info(
|
||||
print(
|
||||
f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
|
||||
)
|
||||
if len(missing) > 0:
|
||||
logger.info(f"Missing Keys: {missing}")
|
||||
print(f"Missing Keys: {missing}")
|
||||
if len(unexpected) > 0:
|
||||
logger.info(f"Unexpected Keys: {unexpected}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def _init_first_stage(self, config):
|
||||
model = instantiate_from_config(config).eval()
|
||||
@@ -182,14 +179,14 @@ class DiffusionEngine(pl.LightningModule):
|
||||
self.model_ema.store(self.model.parameters())
|
||||
self.model_ema.copy_to(self.model)
|
||||
if context is not None:
|
||||
logger.info(f"{context}: Switched to EMA weights")
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.model.parameters())
|
||||
if context is not None:
|
||||
logger.info(f"{context}: Restored training weights")
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
@@ -205,7 +202,7 @@ class DiffusionEngine(pl.LightningModule):
|
||||
opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
logger.debug("Setting up LambdaLR scheduler...")
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import math
|
||||
from inspect import isfunction
|
||||
from typing import Any, Optional
|
||||
@@ -9,10 +8,6 @@ from einops import rearrange, repeat
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
SDP_IS_AVAILABLE = True
|
||||
from torch.backends.cuda import SDPBackend, sdp_kernel
|
||||
@@ -41,9 +36,9 @@ else:
|
||||
SDP_IS_AVAILABLE = False
|
||||
sdp_kernel = nullcontext
|
||||
BACKEND_MAP = {}
|
||||
logger.warning(
|
||||
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
|
||||
f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
|
||||
print(
|
||||
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
|
||||
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -53,7 +48,7 @@ try:
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
logger.debug("no module 'xformers'. Processing without...")
|
||||
print("no module 'xformers'. Processing without...")
|
||||
|
||||
from .diffusionmodules.util import checkpoint
|
||||
|
||||
@@ -294,7 +289,7 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
logger.info(
|
||||
print(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{heads} heads with a dimension of {dim_head}."
|
||||
)
|
||||
@@ -398,21 +393,22 @@ class BasicTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
||||
logger.warning(
|
||||
print(
|
||||
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
|
||||
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
||||
)
|
||||
attn_mode = "softmax"
|
||||
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
||||
logger.warning(
|
||||
print(
|
||||
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
|
||||
)
|
||||
if not XFORMERS_IS_AVAILABLE:
|
||||
raise NotImplementedError(
|
||||
"Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
)
|
||||
logger.info("Falling back to xformers efficient attention.")
|
||||
attn_mode = "softmax-xformers"
|
||||
assert (
|
||||
False
|
||||
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
else:
|
||||
print("Falling back to xformers efficient attention.")
|
||||
attn_mode = "softmax-xformers"
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
|
||||
@@ -441,7 +437,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
if self.checkpoint:
|
||||
logger.info(f"{self.__class__.__name__} is using checkpointing")
|
||||
print(f"{self.__class__.__name__} is using checkpointing")
|
||||
|
||||
def forward(
|
||||
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
||||
@@ -558,7 +554,7 @@ class SpatialTransformer(nn.Module):
|
||||
sdp_backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
logger.debug(
|
||||
print(
|
||||
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
|
||||
)
|
||||
from omegaconf import ListConfig
|
||||
@@ -567,8 +563,8 @@ class SpatialTransformer(nn.Module):
|
||||
context_dim = [context_dim]
|
||||
if exists(context_dim) and isinstance(context_dim, list):
|
||||
if depth != len(context_dim):
|
||||
logger.warning(
|
||||
f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
|
||||
print(
|
||||
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
|
||||
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
|
||||
)
|
||||
# depth does not match context dims.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
import torch
|
||||
@@ -11,9 +10,6 @@ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
|
||||
from ....util import default, instantiate_from_config
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def adopt_weight(weight, global_step, threshold=0, value=0.0):
|
||||
if global_step < threshold:
|
||||
weight = value
|
||||
@@ -108,7 +104,7 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
if self.dims > 2:
|
||||
logger.info(
|
||||
print(
|
||||
f"running with dims={dims}. This means that for perceptual loss calculation, "
|
||||
f"the LPIPS loss will be applied to each frame independently. "
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
@@ -9,8 +8,6 @@ import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from packaging import version
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
@@ -18,7 +15,7 @@ try:
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
logger.debug("no module 'xformers'. Processing without...")
|
||||
print("no module 'xformers'. Processing without...")
|
||||
|
||||
from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
|
||||
|
||||
@@ -291,14 +288,12 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
)
|
||||
attn_type = "vanilla-xformers"
|
||||
logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "vanilla-xformers":
|
||||
logger.debug(
|
||||
f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
|
||||
)
|
||||
print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif type == "memory-efficient-cross-attn":
|
||||
attn_kwargs["query_dim"] = in_channels
|
||||
@@ -638,8 +633,10 @@ class Decoder(nn.Module):
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logger.debug(
|
||||
f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
|
||||
print(
|
||||
"Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)
|
||||
)
|
||||
)
|
||||
|
||||
make_attn_cls = self._make_attn()
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
@@ -22,8 +21,6 @@ from ...modules.diffusionmodules.util import (
|
||||
)
|
||||
from ...util import default, exists
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# dummy replace
|
||||
def convert_module_to_f16(x):
|
||||
@@ -180,13 +177,13 @@ class Downsample(nn.Module):
|
||||
self.dims = dims
|
||||
stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
|
||||
if use_conv:
|
||||
logger.debug(
|
||||
f"Building a Downsample layer with {dims} dims.\n"
|
||||
print(f"Building a Downsample layer with {dims} dims.")
|
||||
print(
|
||||
f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
|
||||
f"kernel-size: 3, stride: {stride}, padding: {padding}"
|
||||
)
|
||||
if dims == 3:
|
||||
logger.debug(f" --> Downsampling third axis (time): {third_down}")
|
||||
print(f" --> Downsampling third axis (time): {third_down}")
|
||||
self.op = conv_nd(
|
||||
dims,
|
||||
self.channels,
|
||||
@@ -273,7 +270,7 @@ class ResBlock(TimestepBlock):
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels
|
||||
)
|
||||
if self.skip_t_emb:
|
||||
logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
|
||||
print(f"Skipping timestep embedding in {self.__class__.__name__}")
|
||||
assert not self.use_scale_shift_norm
|
||||
self.emb_layers = None
|
||||
self.exchange_temb_dims = False
|
||||
@@ -622,12 +619,12 @@ class UNetModel(nn.Module):
|
||||
range(len(num_attention_blocks)),
|
||||
)
|
||||
)
|
||||
logger.warning(
|
||||
print(
|
||||
f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set."
|
||||
)
|
||||
) # todo: convert to warning
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
@@ -636,7 +633,7 @@ class UNetModel(nn.Module):
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
if use_fp16:
|
||||
logger.warning("use_fp16 was dropped and has no effect anymore.")
|
||||
print("WARNING: use_fp16 was dropped and has no effect anymore.")
|
||||
# self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
@@ -667,7 +664,7 @@ class UNetModel(nn.Module):
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == "continuous":
|
||||
logger.debug("setting up linear c_adm embedding layer")
|
||||
print("setting up linear c_adm embedding layer")
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
elif self.num_classes == "timestep":
|
||||
self.label_emb = checkpoint_wrapper_fn(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
@@ -33,8 +32,6 @@ from ...util import (
|
||||
instantiate_from_config,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractEmbModel(nn.Module):
|
||||
def __init__(self):
|
||||
@@ -99,7 +96,7 @@ class GeneralConditioner(nn.Module):
|
||||
for param in embedder.parameters():
|
||||
param.requires_grad = False
|
||||
embedder.eval()
|
||||
logger.debug(
|
||||
print(
|
||||
f"Initialized embedder #{n}: {embedder.__class__.__name__} "
|
||||
f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
|
||||
)
|
||||
@@ -730,7 +727,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
|
||||
)
|
||||
if tokens is not None:
|
||||
tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
|
||||
logger.warning(
|
||||
print(
|
||||
f"You are running very experimental token-concat in {self.__class__.__name__}. "
|
||||
f"Check what you are doing, and then remove this message."
|
||||
)
|
||||
@@ -756,7 +753,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
|
||||
clip_version, device, max_length=clip_max_length
|
||||
)
|
||||
self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
|
||||
logger.debug(
|
||||
print(
|
||||
f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
|
||||
f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
|
||||
)
|
||||
@@ -798,7 +795,7 @@ class SpatialRescaler(nn.Module):
|
||||
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
|
||||
self.remap_output = out_channels is not None or remap_output
|
||||
if self.remap_output:
|
||||
logger.debug(
|
||||
print(
|
||||
f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
|
||||
)
|
||||
self.channel_mapper = nn.Conv2d(
|
||||
|
||||
24
sgm/util.py
24
sgm/util.py
@@ -1,6 +1,5 @@
|
||||
import functools
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
@@ -11,8 +10,6 @@ import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from safetensors.torch import load_file as load_safetensors
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
@@ -89,7 +86,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
logger.warning("Cant encode string %r for logging. Skipping.", lines)
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
@@ -164,7 +161,7 @@ def mean_flat(tensor):
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
||||
print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
@@ -203,11 +200,11 @@ def append_dims(x, target_dims):
|
||||
|
||||
|
||||
def load_model_from_config(config, ckpt, verbose=True, freeze=True):
|
||||
logger.info(f"Loading model from {ckpt}")
|
||||
print(f"Loading model from {ckpt}")
|
||||
if ckpt.endswith("ckpt"):
|
||||
pl_sd = torch.load(ckpt, map_location="cpu")
|
||||
if "global_step" in pl_sd:
|
||||
logger.debug(f"Global Step: {pl_sd['global_step']}")
|
||||
print(f"Global Step: {pl_sd['global_step']}")
|
||||
sd = pl_sd["state_dict"]
|
||||
elif ckpt.endswith("safetensors"):
|
||||
sd = load_safetensors(ckpt)
|
||||
@@ -216,13 +213,14 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
|
||||
|
||||
model = instantiate_from_config(config.model)
|
||||
|
||||
missing, unexpected = model.load_state_dict(sd, strict=False)
|
||||
m, u = model.load_state_dict(sd, strict=False)
|
||||
|
||||
if verbose:
|
||||
if missing:
|
||||
logger.info("missing keys: %r", missing)
|
||||
if unexpected:
|
||||
logger.info("unexpected keys: %r", unexpected)
|
||||
if len(m) > 0 and verbose:
|
||||
print("missing keys:")
|
||||
print(m)
|
||||
if len(u) > 0 and verbose:
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
if freeze:
|
||||
for param in model.parameters():
|
||||
|
||||
Reference in New Issue
Block a user