File: cpp_bmm_template.py

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (254 lines) | stat: -rw-r--r-- 9,099 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# mypy: allow-untyped-defs
import contextlib
import itertools
from typing import Any, Callable, Dict, List, Optional
from unittest.mock import patch

import sympy

from .. import ir
from ..select_algorithm import PartialRender
from ..virtualized import V
from .cpp_gemm_template import CppGemmTemplate, GEMM_TEMPLATE
from .cpp_micro_gemm import LayoutType
from .cpp_template_kernel import CppTemplateKernel
from .cpp_utils import DTYPE_TO_CPP, GemmBlocking


# Jinja stubs that emit the C++ function definitions for the two GEMM routines used by the BMM
# kernel. Each stub calls `kernel.def_kernel` with its own `function_name` and `placeholder`;
# the GEMM body itself is appended separately (see `codegen_single_thread_gemm` and
# `codegen_multi_thread_gemm`).
# We pass all sizevars present in BY to the GEMM templates so variables are not renamed in the BMM definition
# `b_index` is passed as an extra sizevar as well, so the batch index becomes a kernel argument.
GEMM_SINGLE_THREAD_MM_STUB = r"""
{{kernel.def_kernel(
    inputs={"X": X, "W": W},
    outputs={"Y": Y_2d},
    aliases=aliases,
    function_name="single_thread_mm",
    extra_sizevars=BY_sizevars + [b_index],
    placeholder="<SINGLE_THREAD_MM_DEF_FOR_BMM>")}}"""

# Same as the stub above, but defines the function named `threaded_mm`.
GEMM_THREADED_MM_STUB = r"""
{{kernel.def_kernel(
    inputs={"X": X, "W": W},
    outputs={"Y": Y_2d},
    aliases=aliases,
    function_name="threaded_mm",
    extra_sizevars=BY_sizevars + [b_index],
    placeholder="<THREADED_MM_DEF_FOR_BMM>")}}"""

# Top-level BMM kernel. It first emits the micro-kernel and the two GEMM routines, then an
# `extern "C"` entry point that loops over the batch dimension B:
#   * the first `B_single_thread_block` batches (the largest multiple of `num_threads` not
#     exceeding B, or all of B when `num_threads` is 1) each call `single_thread_mm`, inside an
#     OpenMP `parallel for` when `num_threads > 1`;
#   * the remaining `B % num_threads` batches each call `threaded_mm`.
# The `<..._CALL_FOR_BMM>` placeholders are replaced by the hooks registered in
# `CppBmmTemplate.get_gemm_function_call`.
BMM_TEMPLATE = r"""
{{ template.codegen_microkernel_def() }}
{{ template.codegen_single_thread_gemm() }}
{{ template.codegen_multi_thread_gemm() }}

extern "C"
{{kernel.def_kernel(inputs={"X": BX, "W": BW}, outputs={"Y": BY}, aliases=aliases)}}
{
    const int64_t B = {{kernel.size(BY_2d, 0)}};
    {%- if num_threads > 1 %}
    constexpr int64_t num_threads = {{num_threads}};
    int64_t B_single_thread_block = (B / num_threads) * num_threads;

    #pragma omp parallel for num_threads({{num_threads}})
    {%- else %}
    int64_t B_single_thread_block = B;
    {%- endif %}
    for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
        {{template.get_gemm_function_call(
            kernel,
            "single_thread_mm",
            "<SINGLE_THREAD_CALL_FOR_BMM>",
            b_index="b_start",
        )}}
    }
    for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
        {{template.get_gemm_function_call(
            kernel,
            "threaded_mm",
            "<THREADED_MM_CALL_FOR_BMM>",
            b_index="b_start",
        )}}
    }
}
"""


class CppBmmTemplate(CppGemmTemplate):
    """
    C++ template for batched matrix multiplication, built on top of ``CppGemmTemplate``.

    ``GEMM_TEMPLATE`` is rendered twice (once as ``single_thread_mm`` and once as
    ``threaded_mm``) and ``BMM_TEMPLATE`` emits the loop over the batch dimension that
    calls them. The position in the batch is represented by the symbol ``self.b_index``.
    """

    def __init__(
        self,
        input_nodes,
        layout: ir.Layout,
        num_threads: int,
        register_blocking: GemmBlocking,
        beta=1,
        alpha=1,
        has_bias=False,
        epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
        should_block_weights: bool = False,
        name="bmm",
    ):
        """
        In order to simplify the implementation and increase code reuse, the BMM template implements
        two versions of the GEMM kernel: a single-threaded version and a multi-threaded version.
        GEMM kernels are called in a loop over the batch dimension, with single-threaded GEMM calls
        for all but the last (B % num_threads), which are handled by the multi-threaded GEMM kernel.

        We use an extra sizevar `b_index` to index the batch dimension, which we pass into the GEMM
        template as a sympy.Symbol. This allows us to slice the 3D batch tensors in the GEMM template
        without any changes to the GEMM template itself.

        All arguments are forwarded unchanged to ``CppGemmTemplate.__init__``.
        """
        super().__init__(
            input_nodes,
            layout,
            num_threads,
            register_blocking,
            beta=beta,
            alpha=alpha,
            has_bias=has_bias,
            epilogue_creator=epilogue_creator,
            should_block_weights=should_block_weights,
            name=name,
        )
        # Symbolic batch index; replaced by a concrete C++ expression at call sites
        # (see `get_gemm_function_call`).
        self.b_index = sympy.Symbol("s_b_index", integer=True, nonnegative=True)

    @staticmethod
    def get_padded_size(n, block_n, k, should_block_weight):
        """
        Return the size to use for the weight together with the N extent.

        Args:
            n: Size of the N dimension of the weight.
            block_n: Block size along N, forwarded to ``CppGemmTemplate.get_padded_size``.
            k: Size of the K dimension of the weight.
            should_block_weight: Whether the weight will be padded and blocked.

        Returns:
            A ``(new_size, padded_n)`` pair. ``new_size`` always has a leading ``-1``
            entry for the batch dimension. When the weight is not blocked, this is
            ``([-1, k, n], n)``; otherwise it is the parent's result with ``-1``
            inserted at the front of the size list.
        """
        if not should_block_weight:
            # Weight is used as-is: only the batch dimension is prepended to [k, n]
            return [-1, k, n], n

        # Tensor is constant or not contiguous, so we will pad and block
        new_size, padded_n = CppGemmTemplate.get_padded_size(
            n, block_n, k, should_block_weight
        )
        # Add the new batch dimension
        new_size.insert(0, -1)
        return new_size, padded_n

    @staticmethod
    def check_if_block_weight(W, micro_gemm):
        """
        Decide whether the weight ``W`` needs to be blocked.

        Blocking is requested when any of the following holds:
          * the micro-gemm's B layout is not ``LayoutType.NORMAL``;
          * ``W`` is an ``ir.IRNode`` whose layout is not contiguous, or whose name is
            one of the graph constants;
          * ``W`` is not an ``ir.IRNode`` and ``W.is_contiguous()`` is false.
        """
        if micro_gemm.get_b_layout() != LayoutType.NORMAL:
            return True
        if isinstance(W, ir.IRNode):
            return (
                not W.get_layout().is_contiguous()  # type: ignore[union-attr]
                or W.get_name() in V.graph.constants
            )
        return not W.is_contiguous()

    def get_gemm_function_call(
        self,
        kernel: CppTemplateKernel,
        function_name: str,
        placeholder: str,
        b_index: str,
    ) -> str:
        """
        Similar to 'def_kernel' in cpp_template_kernel, but instead of generating a function definition,
        generate a function call for the GEMM kernel.
        Args:
            kernel: The template kernel whose arguments are used to build the call
            function_name: Name of the C++ function to call
            placeholder: The string to replace the function call with
            b_index: The index for slicing the 3D batch tensors. This is a C++ expression
                given as a string (BMM_TEMPLATE passes the loop variable name "b_start");
                it is joined into the generated argument list, so it must be a `str`.
        Returns:
            `placeholder`. The actual call text is produced later by the render hook
            registered under that placeholder.
        """

        def hook():
            arg_defs, call_args, _, _ = kernel.args.python_argdefs()
            # Substitute the concrete batch index wherever the symbolic one is passed
            for i, buf in enumerate(call_args):
                if buf == self.b_index:
                    arg_defs[i] = b_index
            call = f"{function_name}({', '.join(arg_defs)});"
            return call

        # Each placeholder may only be registered once per kernel
        assert placeholder not in kernel.render_hooks
        kernel.render_hooks[placeholder] = hook
        return placeholder

    def get_default_reindexers(self, epilogue_nodes):
        """
        Return one reindexer per epilogue node.

        Every entry is the same function, which prepends ``self.b_index`` to the
        index list it receives.
        """

        def reindexer(args):
            # if epilogue nodes exist, they have 3D ranges but args are 2D, so prepend the
            # symbolic batch index at position 0
            return [self.b_index] + args

        return [reindexer] * len(epilogue_nodes)

    def get_options(
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Build the template options for BMM from the parent GEMM options.

        On top of what ``CppGemmTemplate.get_options`` returns, this:
          * keeps the original buffers under ``BX``, ``BW``, ``BY`` and ``BY_2d``;
          * replaces ``X``, ``W``, ``GemmOut`` and ``Y_2d`` with
            ``kernel.select(buf, 0, self.b_index)``, i.e. the slice at the symbolic
            batch index along dimension 0;
          * refreshes the ``*_dtype`` entries for ``X``, ``W`` and ``Y_2d``;
          * adds ``b_index`` and ``BY_sizevars`` (the free symbols found in the sizes
            and strides of ``BY``).
        """
        options = super().get_options(
            kernel=kernel,
            template_buffer_node=template_buffer_node,
            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
            epilogue_nodes=epilogue_nodes,
            **kwargs,
        )

        # Preserve the batched buffers before the 2D entries are overwritten below
        BX, BW, BY = options["X"], options["W"], options["Y"]
        options["BX"], options["BW"], options["BY"] = BX, BW, BY
        options["BY_2d"] = options["Y_2d"]
        for kword in ["X", "W", "GemmOut", "Y_2d"]:
            options[kword] = kernel.select(options[kword], 0, self.b_index)
        for kword in ["X", "W", "Y_2d"]:
            options[kword + "_dtype"] = DTYPE_TO_CPP[options[kword].dtype]
        options["b_index"] = self.b_index
        # Symbols appearing in BY's sizes and strides; passed to the GEMM stubs as extra sizevars
        options["BY_sizevars"] = [
            s
            for sym in itertools.chain(BY.get_size(), BY.get_stride())
            if isinstance(sym, sympy.Expr)
            for s in sym.free_symbols
        ]

        return options

    def render(  # type: ignore[override, return]
        self,
        kernel: CppTemplateKernel,
        template_buffer_node: Optional[ir.CppTemplateBuffer] = None,
        flag_template_buffer_has_other_users: Optional[bool] = None,
        epilogue_nodes: Optional[List[ir.IRNode]] = None,
        **kwargs,
    ) -> str:
        """
        Render ``BMM_TEMPLATE`` and resolve the BMM-specific placeholders.

        The options are stored on ``self.render_options`` because the template calls
        back into ``codegen_single_thread_gemm`` / ``codegen_multi_thread_gemm``, which
        read them. After rendering, every render hook whose name contains ``"FOR_BMM"``
        is finalized into the result and removed from ``kernel.render_hooks``, and the
        ``b_index`` entry is removed from ``kernel.args.sizevars``.

        Returns:
            The rendered source with the ``FOR_BMM`` placeholders substituted.
        """
        options = self.get_options(
            kernel=kernel,
            template_buffer_node=template_buffer_node,
            flag_template_buffer_has_other_users=flag_template_buffer_has_other_users,
            epilogue_nodes=epilogue_nodes,
            **kwargs,
        )
        self.render_options = options

        with contextlib.ExitStack() as stack:
            # Patch V.graph.get_dtype for each fake buffer while the template is rendered
            for buf in options["fake_buffers"]:
                stack.enter_context(
                    patch.object(V.graph, "get_dtype", self._fake_get_dtype(buf))
                )
            result = self._template_from_string(BMM_TEMPLATE).render(**options)

            # Finalize the function definitions for the gemm routines
            sub_mm_hooks = {
                name: hook
                for name, hook in kernel.render_hooks.items()
                if "FOR_BMM" in name
            }
            result = PartialRender(result, sub_mm_hooks).finalize_all()
            # These hooks are already applied; drop them from the kernel
            for name in sub_mm_hooks:
                del kernel.render_hooks[name]
            # Drop the symbolic batch index from the kernel's size variables
            del kernel.args.sizevars[options["b_index"]]
            return result

    def codegen_single_thread_gemm(self):
        """
        Return the ``single_thread_mm`` definition stub followed by ``GEMM_TEMPLATE``
        rendered with ``num_threads`` overridden to 1.

        Reads ``self.render_options``, which ``render`` sets before the template runs.
        """
        stub = self._template_from_string(GEMM_SINGLE_THREAD_MM_STUB).render(
            self.render_options
        )
        return stub + self._template_from_string(GEMM_TEMPLATE).render(
            {**self.render_options, "num_threads": 1}
        )

    def codegen_multi_thread_gemm(self):
        """
        Return the ``threaded_mm`` definition stub followed by ``GEMM_TEMPLATE``
        rendered with the unmodified render options.

        Reads ``self.render_options``, which ``render`` sets before the template runs.
        """
        stub = self._template_from_string(GEMM_THREADED_MM_STUB).render(
            self.render_options
        )
        return stub + self._template_from_string(GEMM_TEMPLATE).render(
            self.render_options
        )

    def codegen_gemm_stub_def(self):
        """
        Return an empty string.

        NOTE(review): presumably this overrides the parent's stub so that GEMM_TEMPLATE
        does not emit its own kernel definition (the BMM stubs are prepended instead) --
        confirm against CppGemmTemplate.
        """
        return ""