Mirror of https://github.com/Stability-AI/generative-models.git (synced 2025-12-24 00:34:20 +01:00)
This reverts commit 6f6d3f8716.
@@ -1,4 +1,3 @@
-import logging
 import math
 from inspect import isfunction
 from typing import Any, Optional
@@ -9,10 +8,6 @@ from einops import rearrange, repeat
 from packaging import version
 from torch import nn
 
-
-logger = logging.getLogger(__name__)
-
-
 if version.parse(torch.__version__) >= version.parse("2.0.0"):
     SDP_IS_AVAILABLE = True
     from torch.backends.cuda import SDPBackend, sdp_kernel
@@ -41,9 +36,9 @@ else:
     SDP_IS_AVAILABLE = False
     sdp_kernel = nullcontext
     BACKEND_MAP = {}
-    logger.warning(
-        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
-        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
+    print(
+        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
+        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
     )
 
 try:
@@ -53,7 +48,7 @@ try:
     XFORMERS_IS_AVAILABLE = True
 except:
     XFORMERS_IS_AVAILABLE = False
-    logger.debug("no module 'xformers'. Processing without...")
+    print("no module 'xformers'. Processing without...")
 
 from .diffusionmodules.util import checkpoint
 
@@ -294,7 +289,7 @@ class MemoryEfficientCrossAttention(nn.Module):
         self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
     ):
         super().__init__()
-        logger.info(
+        print(
             f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
             f"{heads} heads with a dimension of {dim_head}."
         )
@@ -398,21 +393,22 @@ class BasicTransformerBlock(nn.Module):
         super().__init__()
         assert attn_mode in self.ATTENTION_MODES
         if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
-            logger.warning(
+            print(
                 f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
                 f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
             )
             attn_mode = "softmax"
         elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
-            logger.warning(
+            print(
                 "We do not support vanilla attention anymore, as it is too expensive. Sorry."
             )
             if not XFORMERS_IS_AVAILABLE:
-                raise NotImplementedError(
-                    "Please install xformers via e.g. 'pip install xformers==0.0.16'"
-                )
-            logger.info("Falling back to xformers efficient attention.")
-            attn_mode = "softmax-xformers"
+                assert (
+                    False
+                ), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
+            else:
+                print("Falling back to xformers efficient attention.")
+                attn_mode = "softmax-xformers"
         attn_cls = self.ATTENTION_MODES[attn_mode]
         if version.parse(torch.__version__) >= version.parse("2.0.0"):
             assert sdp_backend is None or isinstance(sdp_backend, SDPBackend)
@@ -441,7 +437,7 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
         self.checkpoint = checkpoint
         if self.checkpoint:
-            logger.info(f"{self.__class__.__name__} is using checkpointing")
+            print(f"{self.__class__.__name__} is using checkpointing")
 
     def forward(
         self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
@@ -558,7 +554,7 @@ class SpatialTransformer(nn.Module):
         sdp_backend=None,
     ):
         super().__init__()
-        logger.debug(
+        print(
             f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
         )
         from omegaconf import ListConfig
@@ -567,8 +563,8 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim]
         if exists(context_dim) and isinstance(context_dim, list):
            if depth != len(context_dim):
-                logger.warning(
-                    f"{self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
+                print(
+                    f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
                     f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
                 )
                 # depth does not match context dims.
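
For context, a minimal self-contained sketch of the two reporting styles this revert switches between, assembled from the lines above (it is not code from either side of the diff as committed): the module-level logging pattern on the removed (-) side and the plain print() on the restored (+) side.

    import logging

    import torch

    # Removed (-) side: route the message through a per-module logger, so the
    # output destination and level are controlled by the application's logging config.
    logger = logging.getLogger(__name__)
    logger.warning(
        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. "
        f"In fact, you are using PyTorch {torch.__version__}. You might want to consider upgrading."
    )

    # Restored (+) side: write the same message unconditionally to stdout.
    print(
        f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
        f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
    )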