diff --git a/scripts/tests/attention.py b/scripts/tests/attention.py
new file mode 100644
index 0000000..4cd5543
--- /dev/null
+++ b/scripts/tests/attention.py
@@ -0,0 +1,319 @@
+import torch
+import einops
+from torch.backends.cuda import SDPBackend
+import torch.nn.functional as F
+import torch.utils.benchmark as benchmark
+
+from sgm.modules.attention import SpatialTransformer, BasicTransformerBlock
+
+
+def benchmark_attn():
+    # Let's define a helpful benchmarking function:
+    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+        t0 = benchmark.Timer(
+            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+        )
+        return t0.blocked_autorange().mean * 1e6
+
+    # Let's define the hyper-parameters of our input
+    batch_size = 32
+    max_sequence_len = 1024
+    num_heads = 32
+    embed_dimension = 32
+
+    dtype = torch.float16
+
+    query = torch.rand(
+        batch_size,
+        num_heads,
+        max_sequence_len,
+        embed_dimension,
+        device=device,
+        dtype=dtype,
+    )
+    key = torch.rand(
+        batch_size,
+        num_heads,
+        max_sequence_len,
+        embed_dimension,
+        device=device,
+        dtype=dtype,
+    )
+    value = torch.rand(
+        batch_size,
+        num_heads,
+        max_sequence_len,
+        embed_dimension,
+        device=device,
+        dtype=dtype,
+    )
+
+    print(f"q/k/v shape:", query.shape, key.shape, value.shape)
+
+    # Let's explore the speed of each of the 3 implementations
+    from torch.backends.cuda import SDPBackend, sdp_kernel
+
+    # Helpful arguments mapper
+    backend_map = {
+        SDPBackend.MATH: {
+            "enable_math": True,
+            "enable_flash": False,
+            "enable_mem_efficient": False,
+        },
+        SDPBackend.FLASH_ATTENTION: {
+            "enable_math": False,
+            "enable_flash": True,
+            "enable_mem_efficient": False,
+        },
+        SDPBackend.EFFICIENT_ATTENTION: {
+            "enable_math": False,
+            "enable_flash": False,
+            "enable_mem_efficient": True,
+        },
+    }
+
+    from torch.profiler import ProfilerActivity, profile, record_function
+
+    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+    print(
+        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+    )
+    with profile(
+        activities=activities, record_shapes=False, profile_memory=True
+    ) as prof:
+        with record_function("Default detailed stats"):
+            for _ in range(25):
+                o = F.scaled_dot_product_attention(query, key, value)
+    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+    print(
+        f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+    )
+    with sdp_kernel(**backend_map[SDPBackend.MATH]):
+        with profile(
+            activities=activities, record_shapes=False, profile_memory=True
+        ) as prof:
+            with record_function("Math implementation stats"):
+                for _ in range(25):
+                    o = F.scaled_dot_product_attention(query, key, value)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
+        try:
+            print(
+                f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+            )
+        except RuntimeError:
+            print("FlashAttention is not supported. See warnings for reasons.")
+        with profile(
+            activities=activities, record_shapes=False, profile_memory=True
+        ) as prof:
+            with record_function("FlashAttention stats"):
+                for _ in range(25):
+                    o = F.scaled_dot_product_attention(query, key, value)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
+        try:
+            print(
+                f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
+            )
+        except RuntimeError:
+            print("EfficientAttention is not supported. See warnings for reasons.")
+        with profile(
+            activities=activities, record_shapes=False, profile_memory=True
+        ) as prof:
+            with record_function("EfficientAttention stats"):
+                for _ in range(25):
+                    o = F.scaled_dot_product_attention(query, key, value)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+
+
+def run_model(model, x, context):
+    return model(x, context)
+
+
+def benchmark_transformer_blocks():
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    import torch.utils.benchmark as benchmark
+
+    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
+        t0 = benchmark.Timer(
+            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
+        )
+        return t0.blocked_autorange().mean * 1e6
+
+    checkpoint = True
+    compile = False
+
+    batch_size = 32
+    h, w = 64, 64
+    context_len = 77
+    embed_dimension = 1024
+    context_dim = 1024
+    d_head = 64
+
+    transformer_depth = 4
+
+    n_heads = embed_dimension // d_head
+
+    dtype = torch.float16
+
+    model_native = SpatialTransformer(
+        embed_dimension,
+        n_heads,
+        d_head,
+        context_dim=context_dim,
+        use_linear=True,
+        use_checkpoint=checkpoint,
+        attn_type="softmax",
+        depth=transformer_depth,
+        sdp_backend=SDPBackend.FLASH_ATTENTION,
+    ).to(device)
+    model_efficient_attn = SpatialTransformer(
+        embed_dimension,
+        n_heads,
+        d_head,
+        context_dim=context_dim,
+        use_linear=True,
+        depth=transformer_depth,
+        use_checkpoint=checkpoint,
+        attn_type="softmax-xformers",
+    ).to(device)
+    if not checkpoint and compile:
+        print("compiling models")
+        model_native = torch.compile(model_native)
+        model_efficient_attn = torch.compile(model_efficient_attn)
+
+    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
+    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
+
+    from torch.profiler import ProfilerActivity, profile, record_function
+
+    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
+
+    with torch.autocast("cuda"):
+        print(
+            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
+        )
+        print(
+            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
+        )
+
+        print(75 * "+")
+        print("NATIVE")
+        print(75 * "+")
+        torch.cuda.reset_peak_memory_stats()
+        with profile(
+            activities=activities, record_shapes=False, profile_memory=True
+        ) as prof:
+            with record_function("NativeAttention stats"):
+                for _ in range(25):
+                    model_native(x, c)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
+
+        print(75 * "+")
+        print("Xformers")
+        print(75 * "+")
+        torch.cuda.reset_peak_memory_stats()
+        with profile(
+            activities=activities, record_shapes=False, profile_memory=True
+        ) as prof:
+            with record_function("xformers stats"):
+                for _ in range(25):
+                    model_efficient_attn(x, c)
+        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
+        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
+
+
+def test01():
+    # conv1x1 vs linear
+    from sgm.util import count_params
+
+    conv = torch.nn.Conv2d(3, 32, kernel_size=1).cuda()
+    print(count_params(conv))
+    linear = torch.nn.Linear(3, 32).cuda()
+    print(count_params(linear))
+
+    print(conv.weight.shape)
+
+    # use same initialization
+    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
+    linear.bias = torch.nn.Parameter(conv.bias)
+
+    print(linear.weight.shape)
+
+    x = torch.randn(11, 3, 64, 64).cuda()
+
+    xr = einops.rearrange(x, "b c h w -> b (h w) c").contiguous()
+    print(xr.shape)
+    out_linear = linear(xr)
+    print(out_linear.mean(), out_linear.shape)
+
+    out_conv = conv(x)
+    print(out_conv.mean(), out_conv.shape)
+    print("done with test01.\n")
+
+
+def test02():
+    # try cosine flash attention
+    import time
+
+    torch.backends.cuda.matmul.allow_tf32 = True
+    torch.backends.cudnn.allow_tf32 = True
+    torch.backends.cudnn.benchmark = True
+    print("testing cosine flash attention...")
+    DIM = 1024
+    SEQLEN = 4096
+    BS = 16
+
+    print(" softmax (vanilla) first...")
+    model = BasicTransformerBlock(
+        dim=DIM,
+        n_heads=16,
+        d_head=64,
+        dropout=0.0,
+        context_dim=None,
+        attn_mode="softmax",
+    ).cuda()
+    try:
+        x = torch.randn(BS, SEQLEN, DIM).cuda()
+        tic = time.time()
+        y = model(x)
+        toc = time.time()
+        print(y.shape, toc - tic)
+    except RuntimeError as e:
+        # likely oom
+        print(str(e))
+
+    print("\n now flash-cosine...")
+    model = BasicTransformerBlock(
+        dim=DIM,
+        n_heads=16,
+        d_head=64,
+        dropout=0.0,
+        context_dim=None,
+        attn_mode="flash-cosine",
+    ).cuda()
+    x = torch.randn(BS, SEQLEN, DIM).cuda()
+    tic = time.time()
+    y = model(x)
+    toc = time.time()
+    print(y.shape, toc - tic)
+    print("done with test02.\n")
+
+
+if __name__ == "__main__":
+    # test01()
+    # test02()
+    # test03()
+
+    # benchmark_attn()
+    benchmark_transformer_blocks()
+
+    print("done.")
diff --git a/sgm/modules/attention.py b/sgm/modules/attention.py
index 4cd2908..9340ddf 100644
--- a/sgm/modules/attention.py
+++ b/sgm/modules/attention.py
@@ -605,317 +605,3 @@ class SpatialTransformer(nn.Module):
         if not self.use_linear:
             x = self.proj_out(x)
         return x + x_in
-
-
-def benchmark_attn():
-    # Lets define a helpful benchmarking function:
-    # https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    import torch.nn.functional as F
-    import torch.utils.benchmark as benchmark
-
-    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
-        t0 = benchmark.Timer(
-            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-        )
-        return t0.blocked_autorange().mean * 1e6
-
-    # Lets define the hyper-parameters of our input
-    batch_size = 32
-    max_sequence_len = 1024
-    num_heads = 32
-    embed_dimension = 32
-
-    dtype = torch.float16
-
-    query = torch.rand(
-        batch_size,
-        num_heads,
-        max_sequence_len,
-        embed_dimension,
-        device=device,
-        dtype=dtype,
-    )
-    key = torch.rand(
-        batch_size,
-        num_heads,
-        max_sequence_len,
-        embed_dimension,
-        device=device,
-        dtype=dtype,
-    )
-    value = torch.rand(
-        batch_size,
-        num_heads,
-        max_sequence_len,
-        embed_dimension,
-        device=device,
-        dtype=dtype,
-    )
-
-    print(f"q/k/v shape:", query.shape, key.shape, value.shape)
-
-    # Lets explore the speed of each of the 3 implementations
-    from torch.backends.cuda import SDPBackend, sdp_kernel
-
-    # Helpful arguments mapper
-    backend_map = {
-        SDPBackend.MATH: {
-            "enable_math": True,
-            "enable_flash": False,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.FLASH_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": True,
-            "enable_mem_efficient": False,
-        },
-        SDPBackend.EFFICIENT_ATTENTION: {
-            "enable_math": False,
-            "enable_flash": False,
-            "enable_mem_efficient": True,
-        },
-    }
-
-    from torch.profiler import ProfilerActivity, profile, record_function
-
-    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
-
-    print(
-        f"The default implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
-    )
-    with profile(
-        activities=activities, record_shapes=False, profile_memory=True
-    ) as prof:
-        with record_function("Default detailed stats"):
-            for _ in range(25):
-                o = F.scaled_dot_product_attention(query, key, value)
-    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-
-    print(
-        f"The math implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
-    )
-    with sdp_kernel(**backend_map[SDPBackend.MATH]):
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("Math implmentation stats"):
-                for _ in range(25):
-                    o = F.scaled_dot_product_attention(query, key, value)
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-
-    with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
-        try:
-            print(
-                f"The flash attention implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
-            )
-        except RuntimeError:
-            print("FlashAttention is not supported. See warnings for reasons.")
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("FlashAttention stats"):
-                for _ in range(25):
-                    o = F.scaled_dot_product_attention(query, key, value)
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-
-    with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
-        try:
-            print(
-                f"The memory efficient implementation runs in {benchmark_torch_function_in_microseconds(F.scaled_dot_product_attention, query, key, value):.3f} microseconds"
-            )
-        except RuntimeError:
-            print("EfficientAttention is not supported. See warnings for reasons.")
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("EfficientAttention stats"):
-                for _ in range(25):
-                    o = F.scaled_dot_product_attention(query, key, value)
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-
-
-def run_model(model, x, context):
-    return model(x, context)
-
-
-def benchmark_transformer_blocks():
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    import torch.utils.benchmark as benchmark
-
-    def benchmark_torch_function_in_microseconds(f, *args, **kwargs):
-        t0 = benchmark.Timer(
-            stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
-        )
-        return t0.blocked_autorange().mean * 1e6
-
-    checkpoint = True
-    compile = False
-
-    batch_size = 32
-    h, w = 64, 64
-    context_len = 77
-    embed_dimension = 1024
-    context_dim = 1024
-    d_head = 64
-
-    transformer_depth = 4
-
-    n_heads = embed_dimension // d_head
-
-    dtype = torch.float16
-
-    model_native = SpatialTransformer(
-        embed_dimension,
-        n_heads,
-        d_head,
-        context_dim=context_dim,
-        use_linear=True,
-        use_checkpoint=checkpoint,
-        attn_type="softmax",
-        depth=transformer_depth,
-        sdp_backend=SDPBackend.FLASH_ATTENTION,
-    ).to(device)
-    model_efficient_attn = SpatialTransformer(
-        embed_dimension,
-        n_heads,
-        d_head,
-        context_dim=context_dim,
-        use_linear=True,
-        depth=transformer_depth,
-        use_checkpoint=checkpoint,
-        attn_type="softmax-xformers",
-    ).to(device)
-    if not checkpoint and compile:
-        print("compiling models")
-        model_native = torch.compile(model_native)
-        model_efficient_attn = torch.compile(model_efficient_attn)
-
-    x = torch.rand(batch_size, embed_dimension, h, w, device=device, dtype=dtype)
-    c = torch.rand(batch_size, context_len, context_dim, device=device, dtype=dtype)
-
-    from torch.profiler import ProfilerActivity, profile, record_function
-
-    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
-
-    with torch.autocast("cuda"):
-        print(
-            f"The native model runs in {benchmark_torch_function_in_microseconds(model_native.forward, x, c):.3f} microseconds"
-        )
-        print(
-            f"The efficientattn model runs in {benchmark_torch_function_in_microseconds(model_efficient_attn.forward, x, c):.3f} microseconds"
-        )
-
-        print(75 * "+")
-        print("NATIVE")
-        print(75 * "+")
-        torch.cuda.reset_peak_memory_stats()
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("NativeAttention stats"):
-                for _ in range(25):
-                    model_native(x, c)
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by native block")
-
-        print(75 * "+")
-        print("Xformers")
-        print(75 * "+")
-        torch.cuda.reset_peak_memory_stats()
-        with profile(
-            activities=activities, record_shapes=False, profile_memory=True
-        ) as prof:
-            with record_function("xformers stats"):
-                for _ in range(25):
-                    model_efficient_attn(x, c)
-        print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
-        print(torch.cuda.max_memory_allocated() * 1e-9, "GB used by xformers block")
-
-
-def test01():
-    # conv1x1 vs linear
-    from ..util import count_params
-
-    conv = nn.Conv2d(3, 32, kernel_size=1).cuda()
-    print(count_params(conv))
-    linear = torch.nn.Linear(3, 32).cuda()
-    print(count_params(linear))
-
-    print(conv.weight.shape)
-
-    # use same initialization
-    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
-    linear.bias = torch.nn.Parameter(conv.bias)
-
-    print(linear.weight.shape)
-
-    x = torch.randn(11, 3, 64, 64).cuda()
-
-    xr = rearrange(x, "b c h w -> b (h w) c").contiguous()
-    print(xr.shape)
-    out_linear = linear(xr)
-    print(out_linear.mean(), out_linear.shape)
-
-    out_conv = conv(x)
-    print(out_conv.mean(), out_conv.shape)
-    print("done with test01.\n")
-
-
-def test02():
-    # try cosine flash attention
-    import time
-
-    torch.backends.cuda.matmul.allow_tf32 = True
-    torch.backends.cudnn.allow_tf32 = True
-    torch.backends.cudnn.benchmark = True
-    print("testing cosine flash attention...")
-    DIM = 1024
-    SEQLEN = 4096
-    BS = 16
-
-    print(" softmax (vanilla) first...")
-    model = BasicTransformerBlock(
-        dim=DIM,
-        n_heads=16,
-        d_head=64,
-        dropout=0.0,
-        context_dim=None,
-        attn_mode="softmax",
-    ).cuda()
-    try:
-        x = torch.randn(BS, SEQLEN, DIM).cuda()
-        tic = time.time()
-        y = model(x)
-        toc = time.time()
-        print(y.shape, toc - tic)
-    except RuntimeError as e:
-        # likely oom
-        print(str(e))
-
-    print("\n now flash-cosine...")
-    model = BasicTransformerBlock(
-        dim=DIM,
-        n_heads=16,
-        d_head=64,
-        dropout=0.0,
-        context_dim=None,
-        attn_mode="flash-cosine",
-    ).cuda()
-    x = torch.randn(BS, SEQLEN, DIM).cuda()
-    tic = time.time()
-    y = model(x)
-    toc = time.time()
-    print(y.shape, toc - tic)
-    print("done with test02.\n")
-
-
-if __name__ == "__main__":
-    # test01()
-    # test02()
-    # test03()
-
-    # benchmark_attn()
-    benchmark_transformer_blocks()
-
-    print("done.")
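
Note (reviewer sketch, not part of the patch): benchmark_attn() above selects scaled_dot_product_attention backends through torch.backends.cuda.sdp_kernel and times them with torch.utils.benchmark. A minimal, self-contained version of that pattern, with small illustrative tensor sizes chosen here rather than taken from the patch, would look roughly like:

    import torch
    import torch.nn.functional as F
    import torch.utils.benchmark as benchmark
    from torch.backends.cuda import sdp_kernel

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    # Small q/k/v tensors (batch, heads, seq, head_dim); sizes are illustrative only.
    q, k, v = (torch.rand(2, 8, 256, 64, device=device, dtype=dtype) for _ in range(3))

    def time_us(f, *args):
        # Same helper pattern as in the patch: mean blocked_autorange time in microseconds.
        t = benchmark.Timer(stmt="f(*args)", globals={"f": f, "args": args})
        return t.blocked_autorange().mean * 1e6

    # Restrict CUDA SDPA to the memory-efficient backend inside this context;
    # flash and math are disabled (on CPU the flags simply have no effect).
    with sdp_kernel(enable_math=False, enable_flash=False, enable_mem_efficient=True):
        print(f"mem-efficient SDPA: {time_us(F.scaled_dot_product_attention, q, k, v):.1f} us")

The moved script itself can be run with the repository root on PYTHONPATH, e.g. PYTHONPATH=. python scripts/tests/attention.py, since it imports from sgm.modules.attention.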
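Similarly, test01 hinges on the fact that a 1x1 convolution is exactly a per-position linear map once the spatial grid is flattened. A tiny CPU-only check of that equivalence (tensor sizes and tolerance here are my own, not from the patch) could be:

    import torch
    import einops

    conv = torch.nn.Conv2d(3, 32, kernel_size=1)
    linear = torch.nn.Linear(3, 32)
    # Reuse the conv parameters in the linear layer, mirroring what test01 does.
    linear.weight = torch.nn.Parameter(conv.weight.squeeze(-1).squeeze(-1))
    linear.bias = torch.nn.Parameter(conv.bias)

    x = torch.randn(2, 3, 8, 8)
    out_conv = einops.rearrange(conv(x), "b c h w -> b (h w) c")
    out_linear = linear(einops.rearrange(x, "b c h w -> b (h w) c"))
    # The two paths should agree up to floating-point noise.
    assert torch.allclose(out_conv, out_linear, atol=1e-5)
    print("1x1 conv matches per-position linear:", out_conv.shape)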