import time

import torch
from torch.nn.functional import multi_head_attention_forward as mha_forward
from torchtext.nn import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct

def benchmark_mha_block():
    def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None):
        # Build the torchtext MultiheadAttention module
        in_proj_container = InProjContainer(
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.Linear(embed_dim, embed_dim),
        )
        MHA = MultiheadAttentionContainer(
            nhead, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)
        ).to(device)
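        # The container chains three pieces: the InProjContainer applies the three
        # Linear layers to query/key/value, the projections are split into `nhead`
        # heads of size embed_dim // nhead, ScaledDotProduct attends over each head,
        # and the final Linear projects the concatenated heads back to embed_dim.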
        query = torch.rand((tgt_len, bsz, embed_dim)).to(device)
        if src_len is None:
            key = value = query
            src_len = tgt_len
        else:
            key = value = torch.rand((src_len, bsz, embed_dim)).to(device)
        attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device)
        attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
        bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device)
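        # The torchtext container consumes a 3-D bool mask of shape
        # (bsz * nhead, tgt_len, src_len), hence the 2-D mask is stacked once per
        # batch-head pair, and per-head bias_k/bias_v, which is why they are
        # reshaped to (1, bsz * nhead, embed_dim // nhead) in the call below.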
print("starting torchtext.modules.MultiheadAttentionContainer")
if device == torch.device("cuda"):
torch.cuda.synchronize()
t0 = time.monotonic()
for _ in range(100):
mha_output, attn_weights = MHA(
query,
key,
value,
attn_mask=attn_mask,
bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
)
if device == torch.device("cuda"):
torch.cuda.synchronize()
print(time.monotonic() - t0)
        # Use torch.nn.functional.multi_head_attention_forward
        torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float("-inf"))
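        # Masked (True) positions in the additive mask become -inf, which the
        # functional API adds to the attention logits before the softmax; this
        # mirrors the bool mask handed to the container above.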
print("starting torch.nn.functional.multi_head_attention_forward")
in_proj_weight = torch.cat(
[
MHA.in_proj_container.query_proj.weight,
MHA.in_proj_container.key_proj.weight,
MHA.in_proj_container.value_proj.weight,
]
)
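        # mha_forward expects the input projections packed as a single
        # (3 * embed_dim, embed_dim) weight, so the three per-tensor weights are
        # concatenated. Note that in_proj_bias is passed as None below, so the
        # functional path skips the projection biases that the container's Linear
        # layers apply; the two timings therefore do not compute identical outputs.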
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        t0 = time.monotonic()
        for _ in range(100):
            torch_mha_output, torch_mha_weights = mha_forward(
                query,
                key,
                value,
                embed_dim,
                nhead,
                in_proj_weight,
                None,  # in_proj_bias
                bias_k,
                bias_v,
                False,  # add_zero_attn
                0.0,  # dropout_p
                MHA.out_proj.weight,
                MHA.out_proj.bias,
                attn_mask=torch_attn_mask,
            )
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        print(time.monotonic() - t0)
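        # Optional sanity check (an illustrative addition, not in the original
        # benchmark): the two paths do not produce identical values here, since the
        # functional call receives in_proj_bias=None while the container's Linear
        # projections carry biases, so only the output shapes are compared.
        assert mha_output.shape == torch_mha_output.shape == (tgt_len, bsz, embed_dim)
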
    # GPU test
    device = torch.device("cuda")
    for embed_dim in [64, 768]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("test case GPU with embed_dim, nhead, seq_len, bsz:", embed_dim, nhead, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, seq_len)
    # GPU test for self-attention
    device = torch.device("cuda")
    for embed_dim in [64, 256]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print(
                        "self-attention test case GPU with embed_dim, nhead, seq_len, bsz:",
                        embed_dim,
                        nhead,
                        seq_len,
                        bsz,
                    )
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)
    # CPU test for self-attention
    device = torch.device("cpu")
    for embed_dim in [64, 768]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("self-attention test case CPU with embed_dim, nhead, seq_len, bsz:", embed_dim, nhead, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)

if __name__ == "__main__":
    benchmark_mha_block()