File: concat.py

import numpy as np

import torch

from . import benchmark


class Concat2D2InputBench(benchmark.Benchmark):
    """Benchmark torch.cat on two 2-D inputs, each preceded by a small
    elementwise add."""

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]

    def forward(self, input1, input2):
        # The tiny adds keep the inputs from being passed straight through and
        # give a fuser an elementwise op to merge with the concat.
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        return y

    def reference(self):
        # NumPy reference mirroring forward(): apply the same small add
        # before concatenating.
        return np.concatenate(
            (self.numpy(self.input1) + 0.00001, self.numpy(self.input2) + 0.00001),
            axis=self.concat_dim,
        )

    def config(self):
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concat2d2input"

    def memory_workload(self):
        # Traffic estimates in elements, as multiples of the total input
        # footprint. "sol" is the speed-of-light model: read every input once
        # and write the result once (1 + 1). "algorithmic" additionally
        # charges the intermediates of the unfused graph (the two adds write
        # temporaries that the cat then re-reads). Backward roughly doubles
        # both counts.
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }
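
    # Worked example for the first large default config below,
    # [1 << 13, 1060, 1 << 13, 1040, 1]:
    #   buffer_size = 8192 * 1060 + 8192 * 1040 = 8192 * 2100 = 17,203,200 elements
    #   fwd "sol"         = 2 * 17,203,200 = 34,406,400 elements of traffic
    #   fwd "algorithmic" = 4 * 17,203,200 = 68,812,800 elements of traffic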

    @staticmethod
    def default_configs():
        # Each row is [I1_D1, I1_D2, I2_D1, I2_D2, concat_dim].
        return [
            [1, 160, 1, 14, 1],
            [1, 580, 1, 174, 1],
            [20, 160, 20, 14, 1],
            [20, 580, 20, 174, 1],
            [8, 512, 8, 512, 1],
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]


benchmark.register_benchmark_class(Concat2D2InputBench)
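

# A minimal, self-contained sketch of what Concat2D2InputBench measures,
# written against plain torch/numpy rather than the benchmark framework.
# The function name and default shapes are illustrative, not part of the
# upstream file.
def _concat2d2input_sketch(I1_D1=20, I1_D2=160, I2_D1=20, I2_D2=14, concat_dim=1):
    a = torch.randn(I1_D1, I1_D2)
    b = torch.randn(I2_D1, I2_D2)
    # Same computation as Concat2D2InputBench.forward().
    y = torch.cat((a + 0.00001, b + 0.00001), dim=concat_dim)
    # NumPy reference, as in Concat2D2InputBench.reference().
    ref = np.concatenate((a.numpy() + 0.00001, b.numpy() + 0.00001), axis=concat_dim)
    np.testing.assert_allclose(y.numpy(), ref, rtol=1e-6)
    return y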


class ConcatGraphOptBench(benchmark.Benchmark):
    """Like Concat2D2InputBench, but with a relu after the cat so the
    TorchScript fuser can optimize the graph around the concat."""

    def __init__(self, mode, device, dtype, I1_D1, I1_D2, I2_D1, I2_D2, concat_dim):
        super().__init__(mode, device, dtype)
        self.I1_D1 = I1_D1
        self.I1_D2 = I1_D2
        self.I2_D1 = I2_D1
        self.I2_D2 = I2_D2
        self.concat_dim = concat_dim
        self.input1 = self.randn(
            [I1_D1, I1_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.input2 = self.randn(
            [I2_D1, I2_D2], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [self.input1, self.input2]
        # Enable the TorchScript fuser on CPU and lower aten::cat without
        # conditionals, so the concat can participate in fused kernels.
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_cat_wo_conditionals(True)

    def forward(self, input1, input2):
        x1 = self.add(input1, 0.00001)
        x2 = self.add(input2, 0.00001)
        y = self.cat((x1, x2), dim=self.concat_dim)
        # The trailing relu gives the fuser an elementwise consumer of the
        # concat output to merge into a single kernel.
        z = self.relu(y)
        return z

    def reference(self):
        # NumPy reference mirroring forward(): add, concatenate, then relu.
        return np.maximum(
            np.concatenate(
                (self.numpy(self.input1) + 0.00001, self.numpy(self.input2) + 0.00001),
                axis=self.concat_dim,
            ),
            0.0,
        )

    def config(self):
        return [self.I1_D1, self.I1_D2, self.I2_D1, self.I2_D2, self.concat_dim]

    @staticmethod
    def module():
        return "concatGraphOpt"

    def memory_workload(self):
        # Same traffic model as Concat2D2InputBench.memory_workload().
        if self.mode == "fwd":
            sol_count = 1 + 1
            algorithmic_count = 3 + 1
        else:
            sol_count = (1 + 1) + (1 + 1)
            algorithmic_count = (3 + 1) + (3 + 1)

        buffer_size = self.I1_D1 * self.I1_D2 + self.I2_D1 * self.I2_D2
        return {
            "sol": buffer_size * sol_count,
            "algorithmic": buffer_size * algorithmic_count,
        }

    @staticmethod
    def default_configs():
        return [
            [1 << 13, 1060, 1 << 13, 1040, 1],
            [1 << 13, 2000, 1 << 13, 1074, 1],
            [1 << 15, 1060, 1 << 15, 2670, 1],
            [1 << 15, 5120, 1 << 15, 2512, 1],
        ]


benchmark.register_benchmark_class(ConcatGraphOptBench)
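

# A minimal sketch of the fusion path ConcatGraphOptBench exercises: script
# the add/cat/relu pattern under the same CPU-fuser flags and check the
# scripted output against eager. The helper below is illustrative and not
# part of the upstream benchmark harness; run it as a module (python -m ...)
# so the relative `benchmark` import above resolves.
if __name__ == "__main__":
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_cat_wo_conditionals(True)

    def _add_cat_relu(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.relu(torch.cat((a + 0.00001, b + 0.00001), dim=1))

    scripted = torch.jit.script(_add_cat_relu)
    a = torch.randn(1 << 13, 1060)
    b = torch.randn(1 << 13, 1040)
    # A few warmup runs let the profiling executor specialize and fuse.
    for _ in range(3):
        out = scripted(a, b)
    torch.testing.assert_close(out, _add_cat_relu(a, b))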