Mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-19 14:24:21 +01:00)
This reverts commit 6f6d3f8716.
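The revert below swaps the per-module Python logging calls back to the repository's original print statements. As a minimal, hypothetical sketch (not taken from the commit itself; the messages are reused from the diff purely for illustration), the two styles this revert toggles between look like this:

import logging

# Style removed by the revert: one module-level logger per file,
# with levelled calls such as warning/info/debug.
logger = logging.getLogger(__name__)
logger.warning("No Validation datapipeline defined, using that one from training")

# Style restored by the revert: plain print statements.
print("Warning: No Validation datapipeline defined, using that one from training")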
@@ -1,4 +1,3 @@
-import logging
 from typing import Optional
 
 import torchdata.datapipes.iter
@@ -6,17 +5,16 @@ import webdataset as wds
 from omegaconf import DictConfig
 from pytorch_lightning import LightningDataModule
 
-logger = logging.getLogger(__name__)
-
 try:
     from sdata import create_dataset, create_dummy_dataset, create_loader
 except ImportError as e:
-    raise NotImplementedError(
-        "Datasets not yet available. "
-        "To enable, we need to add stable-datasets as a submodule; "
-        "please use ``git submodule update --init --recursive`` "
-        "and do ``pip install -e stable-datasets/`` from the root of this repo"
-    ) from e
+    print("#" * 100)
+    print("Datasets not yet available")
+    print("to enable, we need to add stable-datasets as a submodule")
+    print("please use ``git submodule update --init --recursive``")
+    print("and do ``pip install -e stable-datasets/`` from the root of this repo")
+    print("#" * 100)
+    exit(1)
 
 
 class StableDataModuleFromConfig(LightningDataModule):
@@ -41,8 +39,8 @@ class StableDataModuleFromConfig(LightningDataModule):
                     "datapipeline" in self.val_config and "loader" in self.val_config
                 ), "validation config requires the fields `datapipeline` and `loader`"
             else:
-                logger.warning(
-                    "No Validation datapipeline defined, using that one from training"
+                print(
+                    "Warning: No Validation datapipeline defined, using that one from training"
                 )
                 self.val_config = train
 
@@ -54,10 +52,12 @@ class StableDataModuleFromConfig(LightningDataModule):
 
         self.dummy = dummy
         if self.dummy:
-            logger.warning("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
+            print("USING DUMMY DATASET: HOPE YOU'RE DEBUGGING ;)")
+            print("#" * 100)
 
     def setup(self, stage: str) -> None:
-        logger.debug("Preparing datasets")
+        print("Preparing datasets")
         if self.dummy:
             data_fn = create_dummy_dataset
         else:
@@ -1,9 +1,5 @@
-import logging
-
 import numpy as np
 
-logger = logging.getLogger(__name__)
-
 
 class LambdaWarmUpCosineScheduler:
     """
@@ -28,8 +24,9 @@ class LambdaWarmUpCosineScheduler:
         self.verbosity_interval = verbosity_interval
 
     def schedule(self, n, **kwargs):
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
         if n < self.lr_warm_up_steps:
             lr = (
                 self.lr_max - self.lr_start
@@ -86,11 +83,12 @@ class LambdaWarmUpCosineScheduler2:
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
                 cycle
@@ -116,11 +114,12 @@ class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
     def schedule(self, n, **kwargs):
         cycle = self.find_in_interval(n)
         n = n - self.cum_cycles[cycle]
-        if self.verbosity_interval > 0 and n % self.verbosity_interval == 0:
-            logger.info(
-                f"current step: {n}, recent lr-multiplier: {self.last_f}, "
-                f"current cycle {cycle}"
-            )
+        if self.verbosity_interval > 0:
+            if n % self.verbosity_interval == 0:
+                print(
+                    f"current step: {n}, recent lr-multiplier: {self.last_f}, "
+                    f"current cycle {cycle}"
+                )
 
         if n < self.lr_warm_up_steps[cycle]:
             f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[
@@ -1,4 +1,3 @@
-import logging
 import re
 from abc import abstractmethod
 from contextlib import contextmanager
@@ -15,8 +14,6 @@ from ..modules.distributions.distributions import DiagonalGaussianDistribution
 from ..modules.ema import LitEma
 from ..util import default, get_obj_from_str, instantiate_from_config
 
-logger = logging.getLogger(__name__)
-
 
 class AbstractAutoencoder(pl.LightningModule):
     """
@@ -41,7 +38,7 @@ class AbstractAutoencoder(pl.LightningModule):
 
         if self.use_ema:
             self.model_ema = LitEma(self, decay=ema_decay)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         if ckpt_path is not None:
             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
@@ -63,16 +60,16 @@ class AbstractAutoencoder(pl.LightningModule):
         for k in keys:
             for ik in ignore_keys:
                 if re.match(ik, k):
-                    logger.debug(f"Deleting key {k} from state_dict.")
+                    print("Deleting key {} from state_dict.".format(k))
                     del sd[k]
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.debug(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")
 
     @abstractmethod
     def get_input(self, batch) -> Any:
@@ -89,14 +86,14 @@ class AbstractAutoencoder(pl.LightningModule):
             self.model_ema.store(self.parameters())
             self.model_ema.copy_to(self)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")
 
     @abstractmethod
     def encode(self, *args, **kwargs) -> torch.Tensor:
@@ -107,7 +104,7 @@ class AbstractAutoencoder(pl.LightningModule):
         raise NotImplementedError("decode()-method of abstract base class called")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
-        logger.debug(f"loading >>> {cfg['target']} <<< optimizer from config")
+        print(f"loading >>> {cfg['target']} <<< optimizer from config")
         return get_obj_from_str(cfg["target"])(
             params, lr=lr, **cfg.get("params", dict())
         )
@@ -1,4 +1,3 @@
-import logging
 from contextlib import contextmanager
 from typing import Any, Dict, List, Tuple, Union
 
@@ -19,8 +18,6 @@ from ..util import (
     log_txt_as_img,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class DiffusionEngine(pl.LightningModule):
     def __init__(
@@ -76,7 +73,7 @@ class DiffusionEngine(pl.LightningModule):
         self.use_ema = use_ema
         if self.use_ema:
             self.model_ema = LitEma(self.model, decay=ema_decay_rate)
-            logger.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
 
         self.scale_factor = scale_factor
         self.disable_first_stage_autocast = disable_first_stage_autocast
@@ -97,13 +94,13 @@ class DiffusionEngine(pl.LightningModule):
             raise NotImplementedError
 
         missing, unexpected = self.load_state_dict(sd, strict=False)
-        logger.info(
+        print(
             f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys"
         )
         if len(missing) > 0:
-            logger.info(f"Missing Keys: {missing}")
+            print(f"Missing Keys: {missing}")
         if len(unexpected) > 0:
-            logger.info(f"Unexpected Keys: {unexpected}")
+            print(f"Unexpected Keys: {unexpected}")
 
     def _init_first_stage(self, config):
         model = instantiate_from_config(config).eval()
@@ -182,14 +179,14 @@ class DiffusionEngine(pl.LightningModule):
             self.model_ema.store(self.model.parameters())
             self.model_ema.copy_to(self.model)
             if context is not None:
-                logger.info(f"{context}: Switched to EMA weights")
+                print(f"{context}: Switched to EMA weights")
         try:
             yield None
         finally:
             if self.use_ema:
                 self.model_ema.restore(self.model.parameters())
                 if context is not None:
-                    logger.info(f"{context}: Restored training weights")
+                    print(f"{context}: Restored training weights")
 
     def instantiate_optimizer_from_config(self, params, lr, cfg):
         return get_obj_from_str(cfg["target"])(
@@ -205,7 +202,7 @@ class DiffusionEngine(pl.LightningModule):
         opt = self.instantiate_optimizer_from_config(params, lr, self.optimizer_config)
         if self.scheduler_config is not None:
             scheduler = instantiate_from_config(self.scheduler_config)
-            logger.debug("Setting up LambdaLR scheduler...")
+            print("Setting up LambdaLR scheduler...")
             scheduler = [
                 {
                     "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule),
@@ -1,4 +1,3 @@
-import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -9,10 +8,6 @@ from einops import rearrange, repeat
 from packaging import version
 from torch import nn
 
-
-logger = logging.getLogger(__name__)
-
-
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -41,9 +36,9 @@ else:
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    logger.warning(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
-        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    print(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )
 
 try:
@@ -53,7 +48,7 @@ try:
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from .diffusionmodules.util import checkpoint
 
@@ -294,7 +289,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        logger.info(
+        print(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )
@@ -398,21 +393,22 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            logger.warning(
+            print(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            logger.warning(
+            print(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                raise NotImplementedError(
-                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-                )
-            logger.info("Falling back to xformers efficient attention.")
-            attn_mode = "softmax-xformers"
+                assert (
+                    False
+                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+            else:
+                print("Falling back to xformers efficient attention.")
+                attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -441,7 +437,7 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            logger.info(f"{self.__class__.__name__} is using checkpointing")
+            print(f"{self.__class__.__name__} is using checkpointing")
 
     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -558,7 +554,7 @@ class SpatialTransformer(nn.Module):
         sdp_backend=None,
     ):
         super().__init__()
-        logger.debug(
+        print(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig
@@ -567,8 +563,8 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
             if depth != len(context_dim):
-                logger.warning(
-                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                print(
+                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.
@@ -1,4 +1,3 @@
-import logging
 from typing import Any, Union
 
 import torch
@@ -11,9 +10,6 @@ from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
 from ....util import default, instantiate_from_config
 
-
-logger = logging.getLogger(__name__)
-
 
 def adopt_weight(weight, global_step, threshold=0, value=0.0):
     if global_step < threshold:
         weight = value
@@ -108,7 +104,7 @@ class GeneralLPIPSWithDiscriminator(nn.Module):
         super().__init__()
         self.dims = dims
         if self.dims > 2:
-            logger.info(
+            print(
                 f"running with dims={dims}. This means that for perceptual loss calculation, "
                 f"the LPIPS loss will be applied to each frame independently. "
             )
@@ -1,5 +1,4 @@
 # pytorch_diffusion + derived encoder decoder
-import logging
 import math
 from typing import Any, Callable, Optional
 
@@ -9,8 +8,6 @@ import torch.nn as nn
 from einops import rearrange
 from packaging import version
 
-logger = logging.getLogger(__name__)
-
 try:
     import xformers
     import xformers.ops
@@ -18,7 +15,7 @@ try:
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from ...modules.attention import LinearAttention, MemoryEfficientCrossAttention
 
@@ -291,14 +288,12 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
             f"as it is too expensive. Please install xformers via e.g. 'pip install xformers==0.0.16'"
         )
         attn_type = "vanilla-xformers"
-    logger.debug(f"making attention of type '{attn_type}' with {in_channels} in_channels")
+    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
     if attn_type == "vanilla":
         assert attn_kwargs is None
         return AttnBlock(in_channels)
     elif attn_type == "vanilla-xformers":
-        logger.debug(
-            f"building MemoryEfficientAttnBlock with {in_channels} in_channels..."
-        )
+        print(f"building MemoryEfficientAttnBlock with {in_channels} in_channels...")
         return MemoryEfficientAttnBlock(in_channels)
     elif type == "memory-efficient-cross-attn":
         attn_kwargs["query_dim"] = in_channels
@@ -638,8 +633,10 @@ class Decoder(nn.Module):
         block_in = ch * ch_mult[self.num_resolutions - 1]
         curr_res = resolution // 2 ** (self.num_resolutions - 1)
         self.z_shape = (1, z_channels, curr_res, curr_res)
-        logger.debug(
-            f"Working with z of shape {self.z_shape} = {np.prod(self.z_shape)} dimensions."
+        print(
+            "Working with z of shape {} = {} dimensions.".format(
+                self.z_shape, np.prod(self.z_shape)
+            )
         )
 
         make_attn_cls = self._make_attn()
@@ -1,4 +1,3 @@
-import logging
 import math
 from abc import abstractmethod
 from functools import partial
@@ -22,8 +21,6 @@ from ...modules.diffusionmodules.util import (
 )
 from ...util import default, exists
 
-logger = logging.getLogger(__name__)
-
 
 # dummy replace
 def convert_module_to_f16(x):
@@ -180,13 +177,13 @@ class Downsample(nn.Module):
         self.dims = dims
         stride = 2 if dims != 3 else ((1, 2, 2) if not third_down else (2, 2, 2))
         if use_conv:
-            logger.debug(
-                f"Building a Downsample layer with {dims} dims.\n"
+            print(f"Building a Downsample layer with {dims} dims.")
+            print(
                 f" --> settings are: \n in-chn: {self.channels}, out-chn: {self.out_channels}, "
                 f"kernel-size: 3, stride: {stride}, padding: {padding}"
             )
             if dims == 3:
-                logger.debug(f" --> Downsampling third axis (time): {third_down}")
+                print(f" --> Downsampling third axis (time): {third_down}")
         self.op = conv_nd(
             dims,
             self.channels,
@@ -273,7 +270,7 @@ class ResBlock(TimestepBlock):
             2 * self.out_channels if use_scale_shift_norm else self.out_channels
         )
         if self.skip_t_emb:
-            logger.debug(f"Skipping timestep embedding in {self.__class__.__name__}")
+            print(f"Skipping timestep embedding in {self.__class__.__name__}")
             assert not self.use_scale_shift_norm
             self.emb_layers = None
             self.exchange_temb_dims = False
@@ -622,12 +619,12 @@ class UNetModel(nn.Module):
                     range(len(num_attention_blocks)),
                 )
             )
-            logger.warning(
+            print(
                 f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
                 f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
                 f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
                 f"attention will still not be set."
-            )
+            ) # todo: convert to warning
 
         self.attention_resolutions = attention_resolutions
         self.dropout = dropout
@@ -636,7 +633,7 @@ class UNetModel(nn.Module):
         self.num_classes = num_classes
         self.use_checkpoint = use_checkpoint
         if use_fp16:
-            logger.warning("use_fp16 was dropped and has no effect anymore.")
+            print("WARNING: use_fp16 was dropped and has no effect anymore.")
         # self.dtype = th.float16 if use_fp16 else th.float32
         self.num_heads = num_heads
         self.num_head_channels = num_head_channels
@@ -667,7 +664,7 @@ class UNetModel(nn.Module):
             if isinstance(self.num_classes, int):
                 self.label_emb = nn.Embedding(num_classes, time_embed_dim)
             elif self.num_classes == "continuous":
-                logger.debug("setting up linear c_adm embedding layer")
+                print("setting up linear c_adm embedding layer")
                 self.label_emb = nn.Linear(1, time_embed_dim)
             elif self.num_classes == "timestep":
                 self.label_emb = checkpoint_wrapper_fn(
@@ -1,4 +1,3 @@
-import logging
 from contextlib import nullcontext
 from functools import partial
 from typing import Dict, List, Optional, Tuple, Union
@@ -33,8 +32,6 @@ from ...util import (
     instantiate_from_config,
 )
 
-logger = logging.getLogger(__name__)
-
 
 class AbstractEmbModel(nn.Module):
     def __init__(self):
@@ -99,7 +96,7 @@ class GeneralConditioner(nn.Module):
                 for param in embedder.parameters():
                     param.requires_grad = False
                 embedder.eval()
-            logger.debug(
+            print(
                 f"Initialized embedder #{n}: {embedder.__class__.__name__} "
                 f"with {count_params(embedder, False)} params. Trainable: {embedder.is_trainable}"
            )
@@ -730,7 +727,7 @@ class FrozenOpenCLIPImageEmbedder(AbstractEmbModel):
             )
             if tokens is not None:
                 tokens = rearrange(tokens, "(b n) t d -> b t (n d)", n=self.max_crops)
-                logger.warning(
+                print(
                     f"You are running very experimental token-concat in {self.__class__.__name__}. "
                     f"Check what you are doing, and then remove this message."
                 )
@@ -756,7 +753,7 @@ class FrozenCLIPT5Encoder(AbstractEmbModel):
             clip_version, device, max_length=clip_max_length
         )
         self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length)
-        logger.debug(
+        print(
             f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
             f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params."
         )
@@ -798,7 +795,7 @@ class SpatialRescaler(nn.Module):
         self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
         self.remap_output = out_channels is not None or remap_output
         if self.remap_output:
-            logger.debug(
+            print(
                 f"Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing."
             )
             self.channel_mapper = nn.Conv2d(
sgm/util.py
@@ -1,6 +1,5 @@
 import functools
 import importlib
-import logging
 import os
 from functools import partial
 from inspect import isfunction
@@ -11,8 +10,6 @@ import torch
 from PIL import Image, ImageDraw, ImageFont
 from safetensors.torch import load_file as load_safetensors
 
-logger = logging.getLogger(__name__)
-
 
 def disabled_train(self, mode=True):
     """Overwrite model.train with this function to make sure train/eval mode
@@ -89,7 +86,7 @@ def log_txt_as_img(wh, xc, size=10):
         try:
             draw.text((0, 0), lines, fill="black", font=font)
         except UnicodeEncodeError:
-            logger.warning("Cant encode string %r for logging. Skipping.", lines)
+            print("Cant encode string for logging. Skipping.")
 
         txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
         txts.append(txt)
@@ -164,7 +161,7 @@ def mean_flat(tensor):
 def count_params(model, verbose=False):
     total_params = sum(p.numel() for p in model.parameters())
     if verbose:
-        logger.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
     return total_params
 
 
@@ -203,11 +200,11 @@ def append_dims(x, target_dims):
 
 
 def load_model_from_config(config, ckpt, verbose=True, freeze=True):
-    logger.info(f"Loading model from {ckpt}")
+    print(f"Loading model from {ckpt}")
     if ckpt.endswith("ckpt"):
         pl_sd = torch.load(ckpt, map_location="cpu")
         if "global_step" in pl_sd:
-            logger.debug(f"Global Step: {pl_sd['global_step']}")
+            print(f"Global Step: {pl_sd['global_step']}")
         sd = pl_sd["state_dict"]
     elif ckpt.endswith("safetensors"):
         sd = load_safetensors(ckpt)
@@ -216,13 +213,14 @@ def load_model_from_config(config, ckpt, verbose=True, freeze=True):
 
     model = instantiate_from_config(config.model)
 
-    missing, unexpected = model.load_state_dict(sd, strict=False)
+    m, u = model.load_state_dict(sd, strict=False)
 
-    if verbose:
-        if missing:
-            logger.info("missing keys: %r", missing)
-        if unexpected:
-            logger.info("unexpected keys: %r", unexpected)
+    if len(m) > 0 and verbose:
+        print("missing keys:")
+        print(m)
+    if len(u) > 0 and verbose:
+        print("unexpected keys:")
+        print(u)
 
     if freeze:
         for param in model.parameters():