File: attention.py

# This is a copy of rnn_attention from MLPerf, with some common sizes hardcoded
# for benchmarking and some control flow stripped out.
# https://github.com/mlperf/training/blob/master/rnn_translator/pytorch/seq2seq/models/attention.py

from . import benchmark
import torch


class BahdanauAttention(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, b, t_q, t_k, n):
        super().__init__(mode, device, dtype)
        self.b = b
        self.t_q = t_q
        self.t_k = t_k
        self.n = n
        self.att_query = self.rand(
            [b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.att_keys = self.rand(
            [b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.normalize_bias = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.linear_att = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.att_query,
            self.att_keys,
            self.normalize_bias,
            self.linear_att,
        ]

    def forward(self, att_query, att_keys, normalize_bias, linear_att):
        """
        Calculate Bahdanau score

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n

        return b x t_q x t_k scores
        """

        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

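        # Broadcast query and keys to a common (b, t_q, t_k, n) shape;
        # `expand` creates views without copying data.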
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys + normalize_bias
        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def reference(self):
        return self.numpy(self.forward(*self.inputs))

    def config(self):
        return [self.b, self.t_q, self.t_k, self.n]

    @staticmethod
    def module():
        return "attention"

    def memory_workload(self):
        def memsize(t):
            return t.numel() * t.element_size()

        input_size = (
            memsize(self.att_query)
            + memsize(self.att_keys)
            + memsize(self.normalize_bias)
            + memsize(self.linear_att)
        )
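        # The 4-byte element size below assumes float32 outputs and intermediates.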
        output_size = 4 * torch.Size([self.b, self.t_q, self.t_k]).numel()
        io_size = input_size + output_size

        # If matmul is not fused, must write and then read `sum_qk`.
        intermediate_size = (
            2 * 4 * torch.Size([self.b, self.t_q, self.t_k, self.n]).numel()
        )
        return {"sol": io_size, "algorithmic": io_size + intermediate_size}

    @staticmethod
    def default_configs():
        mlperf_inference = [1280, 1, 66, 1024]
        nvidia = [128, 10, 128, 1024]
        return [mlperf_inference, nvidia]


benchmark.register_benchmark_class(BahdanauAttention)
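

# Illustrative sketch (not part of the upstream benchmark): the same Bahdanau
# score computed once with plain tensors, assuming float32 on CPU and small
# hand-picked sizes. It only runs when this module is executed as a script
# inside its package (the relative import above prevents direct execution).
if __name__ == "__main__":
    b, t_q, t_k, n = 2, 3, 5, 8
    query = torch.randn(b, t_q, n)
    keys = torch.randn(b, t_k, n)
    bias = torch.randn(n)
    linear = torch.randn(n)
    # Broadcast query and keys to (b, t_q, t_k, n), add the bias, apply tanh,
    # then contract the last dimension against `linear` to get (b, t_q, t_k) scores.
    scores = torch.tanh(query.unsqueeze(2) + keys.unsqueeze(1) + bias).matmul(linear)
    assert scores.shape == (b, t_q, t_k)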