import einops
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark
from torch.backends.cuda import SDPBackend

from sgm.modules.attention import BasicTransformerBlock, SpatialTransformer


def benchmark_attn():
    # Let's define a helpful benchmarking function:
    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    # Let's define the hyper-parameters of our input
    batch_size = 32
    max_sequence_len = 1024
    num_heads = 32
    embed_dimension = 32

    dtype = torch.float16

    query = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    key = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )
    value = torch.rand(
        batch_size,
        num_heads,
        max_sequence_len,
        embed_dimension,
        device=device,
        dtype=dtype,
    )

    print("q/k/v shape:", query.shape, key.shape, value.shape)

    # Let's explore the speed of each of the 3 implementations
    from torch.backends.cuda import SDPBackend, sdp_kernel

    # Helpful arguments mapper
    backend_map = {
        SDPBackend.MATH: {
            "enable_math": True,
            "enable_flash": False,
            "enable_mem_efficient": False,
        },
        SDPBackend.FLASH_ATTENTION: {
            "enable_math": False,
            "enable_flash": True,
            "enable_mem_efficient": False,
        },
        SDPBackend.EFFICIENT_ATTENTION: {
            "enable_math": False,
            "enable_flash": False,
            "enable_mem_efficient": True,
        },
    }
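    # Note (added): newer PyTorch releases deprecate torch.backends.cuda.sdp_kernel
    # in favor of torch.nn.attention.sdpa_kernel, which takes the backend directly
    # instead of a dict of enable_* flags. A rough, untested sketch of the same
    # backend selection would be:
    #
    #   from torch.nn.attention import SDPBackend, sdpa_kernel
    #   with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    #       F.scaled_dot_product_attention(query, key, value)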

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    print(
        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
    )
    with profile(
        activities=activities, record_shapes=False, profile_memory=True
    ) as prof:
        with record_function("Default detailed stats"):
            for _ in range(25):
                o = F.scaled_dot_product_attention(query, key, value)
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.MATH]):
        print(
            f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
        )
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("Math implementation stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
        try:
            print(
                f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
        except RuntimeError:
            print("FlashAttention is not supported. See warnings for reasons.")
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("FlashAttention stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))

    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
        try:
            print(
                f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
            )
        except RuntimeError:
            print("EfficientAttention is not supported. See warnings for reasons.")
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("EfficientAttention stats"):
                for _ in range(25):
                    o = F.scaled_dot_product_attention(query, key, value)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


def run_model(model, x, context):
    return model(x, context)


def benchmark_transformer_blocks():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    import torch.utils.benchmark as benchmark

    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
        t0 = benchmark.Timer(
            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
        )
        return t0.blocked_autorange().mean * 1e6

    checkpoint = True
    compile = False

    batch_size = 32
    h, w = 64, 64
    context_len = 77
    embed_dimension = 1024
    context_dim = 1024
    d_head = 64

    transformer_depth = 4

    n_heads = embed_dimension // d_head

    dtype = torch.float16

    model_native = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        use_checkpoint=checkpoint,
        attn_type="softmax",
        depth=transformer_depth,
        sdp_backend=SDPBackend.FLASH_ATTENTION,
    ).to(device)
    model_efficient_attn = SpatialTransformer(
        embed_dimension,
        n_heads,
        d_head,
        context_dim=context_dim,
        use_linear=True,
        depth=transformer_depth,
        use_checkpoint=checkpoint,
        attn_type="softmax-xformers",
    ).to(device)
    if not checkpoint and compile:
        print("compiling models")
        model_native = torch.compile(model_native)
        model_efficient_attn = torch.compile(model_efficient_attn)

    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)

    from torch.profiler import ProfilerActivity, profile, record_function

    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]

    with torch.autocast("cuda"):
        print(
            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
        )
        print(
            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
        )

        print(75 * "+")
        print("NATIVE")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("NativeAttention stats"):
                for _ in range(25):
                    model_native(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")

        print(75 * "+")
        print("Xformers")
        print(75 * "+")
        torch.cuda.reset_peak_memory_stats()
        with profile(
            activities=activities, record_shapes=False, profile_memory=True
        ) as prof:
            with record_function("xformers stats"):
                for _ in range(25):
                    model_efficient_attn(x, c)
        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")


def test01():
    # conv1x1 vs linear
    from sgm.util import count_params

    conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
    print(count_params(conv))
    linear = torch.nn.Linear(3, 32).cuda()
    print(count_params(linear))

    print(conv.weight.shape)

    # use same initialization
    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias = torch.nn.Parameter(conv.bias)

    print(linear.weight.shape)

    x = torch.randn(11, 3, 64, 64).cuda()

    xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
    print(xr.shape)
    out_linear = linear(xr)
    print(out_linear.mean(), out_linear.shape)

    out_conv = conv(x)
    print(out_conv.mean(), out_conv.shape)
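
    # Added sanity check (sketch): since the linear layer reuses the conv weights,
    # a 1x1 convolution and a per-pixel linear layer should agree up to
    # floating-point error once the conv output is flattened to (b, h*w, c).
    out_conv_flat = einops.rearrange(out_conv, "b c h w -> b (h w) c")
    print("max |conv1x1 - linear| difference:", (out_conv_flat - out_linear).abs().max().item())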
    print("done with test01.\n")


def test02():
    # try cosine flash attention
    import time

    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    print("testing cosine flash attention...")
    DIM = 1024
    SEQLEN = 4096
    BS = 16

    print(" softmax (vanilla) first...")
    model = BasicTransformerBlock(
        dim=DIM,
        n_heads=16,
        d_head=64,
        dropout=0.0,
        context_dim=None,
        attn_mode="softmax",
    ).cuda()
    try:
        x = torch.randn(BS, SEQLEN, DIM).cuda()
        tic = time.time()
        y = model(x)
        toc = time.time()
        print(y.shape, toc - tic)
    except RuntimeError as e:
        # likely oom
        print(str(e))

    print("\n now flash-cosine...")
    model = BasicTransformerBlock(
        dim=DIM,
        n_heads=16,
        d_head=64,
        dropout=0.0,
        context_dim=None,
        attn_mode="flash-cosine",
    ).cuda()
    x = torch.randn(BS, SEQLEN, DIM).cuda()
    tic = time.time()
    y = model(x)
    toc = time.time()
    print(y.shape, toc - tic)
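    # Note (added): time.time() around an asynchronous CUDA forward pass mostly
    # measures kernel launch overhead rather than compute time. For a more
    # faithful wall-clock number, synchronize before reading the clock, e.g.:
    #
    #   torch.cuda.synchronize()
    #   tic = time.time()
    #   y = model(x)
    #   torch.cuda.synchronize()
    #   toc = time.time()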
    print("done with test02.\n")


if __name__ == "__main__":
    # test01()
    # test02()
    # test03()

    # benchmark_attn()
    benchmark_transformer_blocks()

    print("done.")