mirror of
https://github.com/Stability-AI/generative-models.git
synced 2025-12-21 15:24:22 +01:00
Stable Video Diffusion
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import math
|
||||
from inspect import isfunction
|
||||
from typing import Any, Optional
|
||||
@@ -7,6 +8,9 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
logpy = logging.getLogger(__name__)
|
||||
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
SDP_IS_AVAILABLE = True
|
||||
@@ -36,9 +40,10 @@ else:
|
||||
SDP_IS_AVAILABLE = False
|
||||
sdp_kernel = nullcontext
|
||||
BACKEND_MAP = {}
|
||||
print(
|
||||
f"No SDP backend available, likely because you are running in pytorch versions < 2.0. In fact, "
|
||||
f"you are using PyTorch {torch.__version__}. You might want to consider upgrading."
|
||||
logpy.warn(
|
||||
f"No SDP backend available, likely because you are running in pytorch "
|
||||
f"versions < 2.0. In fact, you are using PyTorch {torch.__version__}. "
|
||||
f"You might want to consider upgrading."
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -48,9 +53,9 @@ try:
|
||||
XFORMERS_IS_AVAILABLE = True
|
||||
except:
|
||||
XFORMERS_IS_AVAILABLE = False
|
||||
print("no module 'xformers'. Processing without...")
|
||||
logpy.warn("no module 'xformers'. Processing without...")
|
||||
|
||||
from .diffusionmodules.util import checkpoint
|
||||
# from .diffusionmodules.util import mixed_checkpoint as checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
@@ -146,6 +151,62 @@ class LinearAttention(nn.Module):
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
ATTENTION_MODES = ("xformers", "torch", "math")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_scale: Optional[float] = None,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
attn_mode: str = "xformers",
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
self.attn_mode = attn_mode
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
B, L, C = x.shape
|
||||
|
||||
qkv = self.qkv(x)
|
||||
if self.attn_mode == "torch":
|
||||
qkv = rearrange(
|
||||
qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
|
||||
).float()
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
||||
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
||||
x = rearrange(x, "B H L D -> B L (H D)")
|
||||
elif self.attn_mode == "xformers":
|
||||
qkv = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B L H D
|
||||
x = xformers.ops.memory_efficient_attention(q, k, v)
|
||||
x = rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
|
||||
elif self.attn_mode == "math":
|
||||
qkv = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # B H L D
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, L, C)
|
||||
else:
|
||||
raise NotImplemented
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
@@ -289,9 +350,10 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
print(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{heads} heads with a dimension of {dim_head}."
|
||||
logpy.debug(
|
||||
f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, "
|
||||
f"context_dim is {context_dim} and using {heads} heads with a "
|
||||
f"dimension of {dim_head}."
|
||||
)
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
@@ -352,9 +414,29 @@ class MemoryEfficientCrossAttention(nn.Module):
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
q, k, v, attn_bias=None, op=self.attention_op
|
||||
)
|
||||
if version.parse(xformers.__version__) >= version.parse("0.0.21"):
|
||||
# NOTE: workaround for
|
||||
# https://github.com/facebookresearch/xformers/issues/845
|
||||
max_bs = 32768
|
||||
N = q.shape[0]
|
||||
n_batches = math.ceil(N / max_bs)
|
||||
out = list()
|
||||
for i_batch in range(n_batches):
|
||||
batch = slice(i_batch * max_bs, (i_batch + 1) * max_bs)
|
||||
out.append(
|
||||
xformers.ops.memory_efficient_attention(
|
||||
q[batch],
|
||||
k[batch],
|
||||
v[batch],
|
||||
attn_bias=None,
|
||||
op=self.attention_op,
|
||||
)
|
||||
)
|
||||
out = torch.cat(out, 0)
|
||||
else:
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
q, k, v, attn_bias=None, op=self.attention_op
|
||||
)
|
||||
|
||||
# TODO: Use this directly in the attention operation, as a bias
|
||||
if exists(mask):
|
||||
@@ -393,21 +475,24 @@ class BasicTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
if attn_mode != "softmax" and not XFORMERS_IS_AVAILABLE:
|
||||
print(
|
||||
f"Attention mode '{attn_mode}' is not available. Falling back to native attention. "
|
||||
f"This is not a problem in Pytorch >= 2.0. FYI, you are running with PyTorch version {torch.__version__}"
|
||||
logpy.warn(
|
||||
f"Attention mode '{attn_mode}' is not available. Falling "
|
||||
f"back to native attention. This is not a problem in "
|
||||
f"Pytorch >= 2.0. FYI, you are running with PyTorch "
|
||||
f"version {torch.__version__}."
|
||||
)
|
||||
attn_mode = "softmax"
|
||||
elif attn_mode == "softmax" and not SDP_IS_AVAILABLE:
|
||||
print(
|
||||
"We do not support vanilla attention anymore, as it is too expensive. Sorry."
|
||||
logpy.warn(
|
||||
"We do not support vanilla attention anymore, as it is too "
|
||||
"expensive. Sorry."
|
||||
)
|
||||
if not XFORMERS_IS_AVAILABLE:
|
||||
assert (
|
||||
False
|
||||
), "Please install xformers via e.g. 'pip install xformers==0.0.16'"
|
||||
else:
|
||||
print("Falling back to xformers efficient attention.")
|
||||
logpy.info("Falling back to xformers efficient attention.")
|
||||
attn_mode = "softmax-xformers"
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
if version.parse(torch.__version__) >= version.parse("2.0.0"):
|
||||
@@ -437,7 +522,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
if self.checkpoint:
|
||||
print(f"{self.__class__.__name__} is using checkpointing")
|
||||
logpy.debug(f"{self.__class__.__name__} is using checkpointing")
|
||||
|
||||
def forward(
|
||||
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
||||
@@ -456,9 +541,12 @@ class BasicTransformerBlock(nn.Module):
|
||||
)
|
||||
|
||||
# return mixed_checkpoint(self._forward, kwargs, self.parameters(), self.checkpoint)
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
if self.checkpoint:
|
||||
# inputs = {"x": x, "context": context}
|
||||
return checkpoint(self._forward, x, context)
|
||||
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
||||
else:
|
||||
return self._forward(**kwargs)
|
||||
|
||||
def _forward(
|
||||
self, x, context=None, additional_tokens=None, n_times_crossframe_attn_in_self=0
|
||||
@@ -518,9 +606,9 @@ class BasicTransformerSingleLayerBlock(nn.Module):
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(
|
||||
self._forward, (x, context), self.parameters(), self.checkpoint
|
||||
)
|
||||
# inputs = {"x": x, "context": context}
|
||||
# return checkpoint(self._forward, inputs, self.parameters(), self.checkpoint)
|
||||
return checkpoint(self._forward, x, context)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x), context=context) + x
|
||||
@@ -554,18 +642,20 @@ class SpatialTransformer(nn.Module):
|
||||
sdp_backend=None,
|
||||
):
|
||||
super().__init__()
|
||||
print(
|
||||
f"constructing {self.__class__.__name__} of depth {depth} w/ {in_channels} channels and {n_heads} heads"
|
||||
logpy.debug(
|
||||
f"constructing {self.__class__.__name__} of depth {depth} w/ "
|
||||
f"{in_channels} channels and {n_heads} heads."
|
||||
)
|
||||
from omegaconf import ListConfig
|
||||
|
||||
if exists(context_dim) and not isinstance(context_dim, (list, ListConfig)):
|
||||
if exists(context_dim) and not isinstance(context_dim, list):
|
||||
context_dim = [context_dim]
|
||||
if exists(context_dim) and isinstance(context_dim, list):
|
||||
if depth != len(context_dim):
|
||||
print(
|
||||
f"WARNING: {self.__class__.__name__}: Found context dims {context_dim} of depth {len(context_dim)}, "
|
||||
f"which does not match the specified 'depth' of {depth}. Setting context_dim to {depth * [context_dim[0]]} now."
|
||||
logpy.warn(
|
||||
f"{self.__class__.__name__}: Found context dims "
|
||||
f"{context_dim} of depth {len(context_dim)}, which does not "
|
||||
f"match the specified 'depth' of {depth}. Setting context_dim "
|
||||
f"to {depth * [context_dim[0]]} now."
|
||||
)
|
||||
# depth does not match context dims.
|
||||
assert all(
|
||||
@@ -631,3 +721,39 @@ class SpatialTransformer(nn.Module):
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
|
||||
|
||||
class SimpleTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
depth: int,
|
||||
heads: int,
|
||||
dim_head: int,
|
||||
context_dim: Optional[int] = None,
|
||||
dropout: float = 0.0,
|
||||
checkpoint: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList([])
|
||||
for _ in range(depth):
|
||||
self.layers.append(
|
||||
BasicTransformerBlock(
|
||||
dim,
|
||||
heads,
|
||||
dim_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim,
|
||||
attn_mode="softmax-xformers",
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
context: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, context)
|
||||
return x
|
||||
|
||||
Reference in New Issue
Block a user