File: gemm_grouped.py

package info (click to toggle)
nvidia-cutlass 3.4.1%2Bds-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 48,488 kB
  • sloc: cpp: 206,571; ansic: 69,215; python: 25,487; sh: 16; makefile: 15
file content (264 lines) | stat: -rw-r--r-- 12,455 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
    Ease-of-use interface for constructing, compiling, and running GEMMs.

    The ``GroupedGemm`` interface is meant to allow one to easily instantiate, compile, and run
    grouped GEMM operations in CUTLASS via Python, without specifying many configuration parameters.
    Under the hood, the interface will select sensible default parameters for the many template
    parameters for CUTLASS grouped GEMMs.

    Note: optimal performance is not to be expected from this interface. To achieve optimal
    performance, one should specify and tune each configuration parameter.

    The simplest example of using this interface is the following:

    .. highlight:: python
    .. code-block:: python

        # As, Bs, Cs, and Ds are torch/numpy/cupy tensor objects
        plan = cutlass.op.GroupedGemm(element=cutlass.DataType.f16, layout=cutlass.LayoutType.RowMajor)
        plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1])
"""

from cutlass_library import DataTypeSize

from cuda import cuda
from cutlass.backend.gemm_operation import (
    GemmGroupedArguments,
    GemmOperationGrouped,
)
from cutlass.backend.library import (
    SchedulerMode,
    TensorDescription,
    TileDescription,
)
from cutlass.op.gemm import Gemm
from cutlass.shape import GemmCoord
from cutlass.utils import check, datatypes


class GroupedGemm(Gemm):
    """
    Constructs a ``GroupedGemm`` object.

    The data types and layouts of operands A, B, and C, along with the data type of output D
    and that used for accumulation, are bound to the ``GroupedGemm`` object throughout its lifetime --
    these are not to be changed after a ``GroupedGemm`` has been constructed.

    The constructor has optional parameters for flexibly setting these parameters. Please see the constructor
    for ``Gemm`` for examples of these.

    :param cc: compute capability of device to generate kernels for
    :type cc: int
    :param A: tensor representing data type and layout of operands A
    :param B: tensor representing data type and layout of operands B
    :param C: tensor representing data type and layout of operands C
    :param D: tensor representing data type and layout of operands D
    :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
    :param beta: scalar parameter beta from GEMM operation that scales operand C
    :param element_accumulator: data type to be used in accumulation of the product of operands A and B
    :type element_accumulator: cutlass.DataType
    :param element: generic data type to be used for operands A, B, C, D, as well as the accumulation data type
    :type element: cutlass.DataType
    :param layout: generic layout type to be used for operands A, B, C, and D
    :type layout: cutlass.LayoutType
    :param element_A: data type to be used for operand A
    :type element_A: cutlass.DataType
    :param element_B: data type to be used for operand B
    :type element_B: cutlass.DataType
    :param element_C: data type to be used for operand C
    :type element_C: cutlass.DataType
    :param element_D: data type to be used for operand D
    :type element_D: cutlass.DataType
    :param layout_A: layout of operand A
    :type layout_A: cutlass.LayoutType
    :param layout_B: layout of operand B
    :type layout_B: cutlass.LayoutType
    :param layout_C: layout of operand C
    :type layout_C: cutlass.LayoutType
    :param layout_D: layout of operand D
    :type layout_D: cutlass.LayoutType
    """

    def __init__(
        self, A=None, B=None, C=None, D=None,
        alpha=1.0, beta=0.0, element_accumulator=None,
        element=None, layout=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        layout_A=None, layout_B=None, layout_C=None,
        cc: int = None,
    ):
        super().__init__(
            A=A, B=B, C=C, D=D,
            alpha=alpha, beta=beta,
            element_accumulator=element_accumulator,
            element=element, layout=layout,
            element_A=element_A, element_B=element_B,
            element_C=element_C, element_D=element_D,
            layout_A=layout_A, layout_B=layout_B, layout_C=layout_C,
            cc=cc
        )

        # Grouped GEMM specializations for SM90 are currently unavailable. Revert to using SM80
        if self.current_cc == 90:
            self._reset_options(80)
            self._reset_operations(reset_epilogue=False)

        self.name = "grouped_gemm"

    @Gemm.swizzling_functor.setter
    def swizzling_functor(self, swizzling_functor):
        """
        Sets the swizzling functor to the type specified by `swizzling_functor`
        """
        raise Exception('Grouped GEMM does not currently support different swizzling functors')

    def construct(self, tile_description: TileDescription = None,
                  alignment_A: int = None,
                  alignment_B: int = None,
                  alignment_C: int = None) -> GemmOperationGrouped:
        """
        Constructs a ``cutlass.backend.GemmOperationGrouped`` based on the input parameters and current
        kernel specification of the ``Gemm`` object.

        :param tile_description: tile description specifying shapes and operand types to use in the kernel
        :type tile_description: cutlass.backend.TileDescription
        :param alignment_A: alignment of operand A
        :type alignment_A: int
        :param alignment_B: alignment of operand B
        :type alignment_B: int
        :param alignment_C: alignment of operand C
        :type alignment_C: int

        :return: operation that was constructed
        :rtype: cutlass.backend.GemmOperationGrouped
        """
        # Default to the widest alignment supported by the candidate operations when none is given
        alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A")))
        alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B")))
        alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C")))

        self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor)

        # Bug fix: tensor_A previously used self._layout_b, which silently built operand A
        # with operand B's layout whenever the two layouts differed.
        tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A)
        tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B)
        tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C)

        if tile_description is None:
            # No explicit tile description: take the first candidate operation matching the
            # requested alignments and math operation, and derive its tile description
            op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C, self._math_operation)[0]
            tile_description = datatypes.td_from_profiler_op(op)
        else:
            valid, err_str = self._valid_tile_description(tile_description)
            if not valid:
                raise Exception(f"Invalid tile description. {err_str}")
            self.tile_description = tile_description

        operation = GemmOperationGrouped(
            arch=self.current_cc,
            tile_description=tile_description,
            A=tensor_A, B=tensor_B, C=tensor_C,
            epilogue_functor=self.epilogue_functor,
            swizzling_functor=self._swizzling_functor,
            precompute_mode=SchedulerMode.Device)

        return operation

    def run(self, A, B, C, D,
            alpha=None, beta=None, sync: bool = True,
            print_module: bool = False,
            stream: cuda.CUstream = cuda.CUstream(0)) -> GemmGroupedArguments:
        """
        Runs the kernel currently specified.

        By default, this call returns only once the kernel has completed. To launch the kernel
        and immediately return, set ``sync=False``. In this case, it is the responsibility of the
        caller to synchronize the results of the kernel before attempting to access outputs
        by calling ``sync()`` on the arguments returned from this call.

        :param A: list of tensors representing data type and layout of operand A
        :type A: list
        :param B: list of tensors representing data type and layout of operand B
        :type B: list
        :param C: list of tensors representing data type and layout of operand C
        :type C: list
        :param D: list of tensors representing data type and layout of operand D
        :type D: list
        :param alpha: scalar parameter alpha from GEMM computation that scales the product of operands A and B
        :param beta: scalar parameter beta from GEMM operation that scales operand C
        :param sync: whether the call should wait for the kernel to complete before returning
        :type sync: bool
        :param print_module: whether to print the emitted C++ code
        :type print_module: bool
        :param stream: cuda stream, defaults to cuda.cuda.CUstream(0)
        :type stream: :class:`cuda.cuda.CUstream`

        :return: arguments passed in to the kernel
        :rtype: cutlass.backend.GemmGroupedArguments

        :raises Exception: if the A, B, C, and D lists do not all have the same length
        """
        super().run_setup()

        # Each list entry i describes one GEMM in the group; all four lists must align
        if len(A) != len(B) or len(A) != len(C) or len(A) != len(D):
            raise Exception("Lengths of A, B, C, and D lists must be equal")

        problem_sizes = []
        As, Bs, Cs, Ds = ([None] * len(A) for _ in range(4))
        for i in range(len(A)):
            As[i] = self._verify_tensor(A[i], self.A, self._element_a, self._layout_a, "A")
            Bs[i] = self._verify_tensor(B[i], self.B, self._element_b, self._layout_b, "B")
            Cs[i] = self._verify_tensor(C[i], self.C, self._element_c, self._layout_c, "C")
            Ds[i] = self._verify_tensor(D[i], self.D, self._element_d, self._layout_d, "D")
            # Problem size (M, N, K) for GEMM i, taken from A's (M, K) and B's (K, N)
            problem_sizes.append(GemmCoord(A[i].shape[0], B[i].shape[1], A[i].shape[1]))

        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        # A single kernel serves the whole group, so the usable alignment for each operand
        # is the minimum alignment over all tensors in that operand's list
        alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As))
        alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs))
        alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs))
        self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, print_module=print_module)

        arguments = GemmGroupedArguments(
            operation=self.operation,
            problem_sizes=problem_sizes,
            A=As, B=Bs, C=Cs, D=Ds,
            output_op=self.operation.epilogue_type(alpha, beta),
            stream=stream
        )

        self.operation.run(arguments)

        if sync:
            arguments.sync()

        return arguments