File: mha_block.py

import time

import torch
from torch.nn.functional import multi_head_attention_forward as mha_forward
from torchtext.modules import InProjContainer, MultiheadAttentionContainer, ScaledDotProduct


def benchmark_mha_block():
    def _run_benchmark(embed_dim, nhead, bsz, device, tgt_len, src_len=None):
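        # Times 100 forward passes through two equivalent attention paths:
        # torchtext's MultiheadAttentionContainer and the fused
        # torch.nn.functional.multi_head_attention_forward.
        # When src_len is None, key and value reuse `query` (self-attention).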
        # Build torchtext MultiheadAttention module
        in_proj_container = InProjContainer(
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.Linear(embed_dim, embed_dim),
        )
        MHA = MultiheadAttentionContainer(
            nhead, in_proj_container, ScaledDotProduct(), torch.nn.Linear(embed_dim, embed_dim)
        ).to(device)

        query = torch.rand((tgt_len, bsz, embed_dim)).to(device)
        if src_len is None:
            key = value = query
            src_len = tgt_len
        else:
            key = value = torch.rand((src_len, bsz, embed_dim)).to(device)
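        # Boolean mask: True entries are positions attention must not attend to.
        # The container expects one (tgt_len, src_len) mask per batch * head slice.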
        attn_mask_2D = torch.randint(0, 2, (tgt_len, src_len)).to(torch.bool).to(device)
        attn_mask = torch.stack([attn_mask_2D] * (bsz * nhead))
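        # Shared key/value bias vectors; reshaped to one slice per head when passed to the container.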
        bias_k = bias_v = torch.rand((1, 1, embed_dim)).to(device)
        print("starting torchtext.modules.MultiheadAttentionContainer")
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        t0 = time.monotonic()
        for _ in range(100):
            mha_output, attn_weights = MHA(
                query,
                key,
                value,
                attn_mask=attn_mask,
                bias_k=bias_k.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
                bias_v=bias_v.repeat(1, bsz, 1).reshape(1, bsz * nhead, -1),
            )
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        print(time.monotonic() - t0)

        # Use torch.nn.functional.multi_head_attention_forward
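        # The functional API takes an additive float mask, so masked (True) positions become -inf.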
        torch_attn_mask = torch.zeros((tgt_len, src_len)).to(device).masked_fill_(attn_mask_2D, float("-inf"))
        print("starting torch.nn.functional.multi_head_attention_forward")
        in_proj_weight = torch.cat(
            [
                MHA.in_proj_container.query_proj.weight,
                MHA.in_proj_container.key_proj.weight,
                MHA.in_proj_container.value_proj.weight,
            ]
        )
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        t0 = time.monotonic()
        for _ in range(100):
            torch_mha_output, torch_mha_weights = mha_forward(
                query,
                key,
                value,
                embed_dim,
                nhead,
                in_proj_weight,
                None,  # in_proj_bias
                bias_k,
                bias_v,
                False,  # add_zero_attn
                0.0,  # dropout_p
                MHA.out_proj.weight,
                MHA.out_proj.bias,
                attn_mask=torch_attn_mask,
            )
        if device == torch.device("cuda"):
            torch.cuda.synchronize()
        print(time.monotonic() - t0)

    # GPU test with separate key/value tensors
    device = torch.device("cuda")
    for embed_dim in [64, 768]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("test case GPU with embed_dim, nhead, seq_len, bsz:", embed_dim, nhead, seq_len, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, seq_len)

    # GPU test for self-attention
    device = torch.device("cuda")
    for embed_dim in [64, 256]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print(
                        "self-attention test case GPU with embed_dim, nhead, seq_len, bsz:",
                        embed_dim,
                        nhead,
                        seq_len,
                        bsz,
                    )
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)

    # CPU test for self-attention
    device = torch.device("cpu")
    for embed_dim in [64, 768]:
        for nhead in [2, 16]:
            for seq_len in [10, 128, 1000]:
                for bsz in [2, 72]:
                    if seq_len == 1000 and bsz == 72:
                        continue
                    print("*" * 80)
                    print("test case CPU with embed_dim, nhead, seq_len, bsz:", embed_dim, nhead, seq_len, seq_len, bsz)
                    _run_benchmark(embed_dim, nhead, bsz, device, seq_len, None)


if __name__ == "__main__":
    benchmark_mha_block()
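
# Running this script directly executes all three benchmark sweeps; the GPU sweeps
# assume a CUDA-capable device is available (e.g. `python mha_block.py`).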