File: check.py

package info (click to toggle)
nvidia-cutlass 3.4.1%2Bds-2
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 48,488 kB
  • sloc: cpp: 206,571; ansic: 69,215; python: 25,487; sh: 16; makefile: 15
file content (269 lines) | stat: -rw-r--r-- 12,125 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""
Utility functions for checking constraints on kernels and calculating kernel attributes
"""

import ctypes

from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC

import cutlass
from cutlass.backend.library import TileDescription


def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int:
    """
    Returns the amount of shared memory in bytes consumed in a single stage of a kernel.

    Only ``OperationKind.Gemm`` is handled. For a GEMM, the result is the size of the
    ``m x k`` tile of operand A plus the ``k x n`` tile of operand B plus a fixed
    per-stage barrier allowance.

    :param td: tile description to compute shared memory of
    :type td: TileDescription
    :param operation_kind: identifier for the type of operation being performed
    :type operation_kind: cutlass_library.OperationKind

    :return: number of bytes of shared memory consumed by a single stage
    :rtype: int

    :raises Exception: if ``operation_kind`` is anything other than ``OperationKind.Gemm``
    """
    m, n, k = td.threadblock_shape

    if operation_kind != OperationKind.Gemm:
        # Report the kind that was actually passed in. There is no `operation` object in
        # this function's scope, so the message must be built from the parameter.
        raise Exception(f"No available shared memory calculation for operation kind {operation_kind}")

    math_inst = td.math_instruction
    stage_barrier_bytes = 32

    # DataTypeSize entries are divided by 8 here, i.e. they are treated as widths in bits.
    bytes_a = DataTypeSize[math_inst.element_a] * m * k // 8
    bytes_b = DataTypeSize[math_inst.element_b] * k * n // 8
    return bytes_a + bytes_b + stage_barrier_bytes


def calculate_smem_usage(operation) -> int:
    """
    Returns the amount of shared memory in bytes consumed by a kernel, computed as the
    per-stage usage multiplied by the stage count of the operation's tile description.

    :param operation: operation providing the ``tile_description`` and ``operation_kind``
                      attributes used in the calculation

    :return: number of bytes of shared memory consumed by the operation
    :rtype: int
    """
    tile_desc = operation.tile_description
    bytes_per_stage = calculate_smem_usage_per_stage(tile_desc, operation.operation_kind)
    return bytes_per_stage * tile_desc.stages


def valid_stage_count(
    cc: int,
    kernel_cc: int,
    td: TileDescription,
    element_C: cutlass.DataType = None,
    element_D: cutlass.DataType = None,
    verbose: bool = True) -> tuple:
    """
    Checks whether a device with `cc` supports the number of stages within `tile_description`, both
    based on raw limits on the number of stages and based on shared memory capacity

    :param cc: compute capability of device in question
    :type cc: int
    :param kernel_cc: compute capability that the kernel targets (corresponding to the arch::SMxy tag in CUTLASS)
    :type kernel_cc: int
    :param td: tile description to check
    :type td: TileDescription
    :param element_C: data type of operand C (accepted for interface compatibility; not read by this check)
    :type element_C: cutlass.DataType
    :param element_D: data type of operand D (accepted for interface compatibility; not read by this check)
    :type element_D: cutlass.DataType
    :param verbose: whether to log warnings
    :type verbose: bool

    :return: tuple with the first element indicating whether the provided tile description is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    if kernel_cc == 90:
        if (td.stages is None or td.stages == 0):
            # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically
            # determines the stage count to use. Thus, all settings are valid in these scenarios.
            return (True, "")
        elif verbose:
            cutlass.logger.warning(
                "Setting an explicit stage count for SM90 kernels currently may "
                "result in compilation errors if the combination of tile shape, "
                "stage count, and shared memory requirement of the epilogue exceeds "
                "the available shared memory per SM.")

    # An unset (None) stage count is only accepted for SM90 kernels, which are handled above.
    # Reject it explicitly here so callers receive the documented (False, message) tuple
    # instead of a TypeError from comparing None with an integer.
    if td.stages is None or td.stages <= 0:
        return (False, f"Stage counts must be positive integers. Tile description has stage count of {td.stages}.")

    if cc < 80 and td.stages != 2:
        return (False, f"Tile description has stage count of {td.stages}, "
                       f"but only 2 stages are supported on SM{cc}.")

    # The calculation below does not consider shared memory used by the epilogue and, thus,
    # only catches cases in which the mainloop exceeds the device's shared memory capacity.
    # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the
    # mainloop and epilogue is shared.
    smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm)
    smem_usage_mainloop = (smem_per_stage * td.stages)
    # The table entry is multiplied by 1024 (presumably KiB -> bytes) so that it can be
    # compared against the byte count computed above.
    smem_arch = SharedMemPerCC[cc] << 10
    if smem_usage_mainloop > smem_arch:
        return ( False,
            "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n"
            f"Details:\n"
            f"Mainloop uses {smem_per_stage} bytes of shared memory per stage, and "
            f"{td.stages} stages for a total of {smem_usage_mainloop} bytes.\n"
            f"The maxmium amount of shared memory that can be used per block on CC {cc} is {smem_arch}.")

    return (True, "")


def valid_cluster_shape(cc: int, cluster_shape: list) -> tuple:
    """
    Checks whether a device with `cc` supports a thread block cluster of shape `cluster_shape`.

    For compute capabilities below 90 the only accepted shape is ``[1, 1, 1]``. Otherwise the
    shape must be rank-3, have a third dimension of 1, and contain at most 8 thread blocks
    across its first two dimensions.

    :param cc: compute capability of device in question
    :type cc: int
    :param cluster_shape: dimensions of thread block cluster shape to check
    :type cluster_shape: list

    :return: tuple with the first element indicating whether the provided cluster shape is
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """

    if cc < 90:
        if cluster_shape != [1, 1, 1]:
            return (False,
                    f"Cluster shape for pre-SM90 architectures must be [1, 1, 1]. Received cluster shape of "
                    f"{cluster_shape} for SM{cc}.")
        else:
            return (True, "")

    if len(cluster_shape) != 3:
        # The closing parenthesis after the rank keeps the message well-formed.
        return (False,
                f"Cluster shapes must be rank-3. Received {cluster_shape} (rank {len(cluster_shape)})")

    if cluster_shape[2] != 1:
        return (False,
                "CUTLASS kernels currently require the third dimension of cluster shape to be 1. "
                f"Received cluster shape of {cluster_shape}.")

    # The CUDA programming guide currently defines a maximum of 8 thread blocks per cluster
    # as being portably supported (https://docs.nvidia.com/cuda/cuda-c-programming-guide/#thread-block-clusters).
    # Current CUTLASS kernels only have non-unit cluster dimensions within the first two dimensions,
    # so we check that the first two dimensions of the cluster shape do not exceed 8 thread blocks in total.
    blocks_in_2d = cluster_shape[0] * cluster_shape[1]
    if blocks_in_2d > 8:
        return (False,
            f"Thread block clusters with more than 8 thread blocks are currently unsupported on SM{cc}. "
            f"Received cluster shape {cluster_shape}, which has {blocks_in_2d} thread blocks.")
    return (True, "")


def valid_schedule(
    cc: int,
    kernel_schedule: cutlass.KernelScheduleType,
    epilogue_schedule: cutlass.EpilogueScheduleType,
    tile_scheduler: cutlass.TileSchedulerType) -> tuple:
    """
    Validates the combination of kernel schedule, epilogue schedule, and tile scheduler
    for a device of compute capability ``cc``.

    The checks performed, in order, are:

    1. Below SM90, all three settings must be left at their auto/default values.
    2. The kernel and epilogue schedules must be both auto or both explicit.
    3. The Stream-K tile scheduler requires one of the cooperative kernel schedules.

    :param cc: compute capability of device in question
    :type cc: int
    :param kernel_schedule: kernel schedule type
    :type kernel_schedule: cutlass.KernelScheduleType
    :param epilogue_schedule: epilogue schedule type
    :type epilogue_schedule: cutlass.EpilogueScheduleType
    :param tile_scheduler: tile scheduler type
    :type tile_scheduler: cutlass.TileSchedulerType

    :return: tuple with the first element indicating whether the provided schedules are
             valid for the provided device and the second element being an error message
    :rtype: tuple
    """
    is_kernel_auto = kernel_schedule == cutlass.KernelScheduleType.ScheduleAuto
    is_epilogue_auto = epilogue_schedule == cutlass.EpilogueScheduleType.ScheduleAuto
    is_default_scheduler = tile_scheduler == cutlass.TileSchedulerType.Default

    everything_default = is_kernel_auto and is_epilogue_auto and is_default_scheduler
    if cc < 90 and not everything_default:
        return (False, "Non-default schedules are only supported on SM90 and beyond")

    # Exactly one of the two schedules being auto is a mismatch.
    if is_kernel_auto != is_epilogue_auto:
        return (False, "Kernel and epilogue schedules must either both be auto or neither be auto")

    if is_default_scheduler:
        return (True, "")

    cooperative_schedules = (
        cutlass.KernelScheduleType.TmaWarpSpecializedCooperative,
        cutlass.KernelScheduleType.CpAsyncWarpSpecializedCooperative,
    )
    wants_stream_k = tile_scheduler == cutlass.TileSchedulerType.StreamK
    if wants_stream_k and kernel_schedule not in cooperative_schedules:
        return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule")
    return (True, "")


def alignment_or_default(alignment_provided: int, default_alignment: int) -> int:
    """
    Selects the alignment to use: `alignment_provided` when it is not None, otherwise
    `default_alignment`. A provided alignment larger than `default_alignment` is rejected.

    :param alignment_provided: alignment preference specified. Can be None.
    :type alignment_provided: int
    :param default_alignment: alignment to use if `alignment_provided` is None
    :type default_alignment: int

    :return: alignment to use
    :rtype: int

    :raises Exception: if `alignment_provided` exceeds `default_alignment`
    """
    # Nothing requested: fall back to the default.
    if alignment_provided is None:
        return default_alignment

    if alignment_provided > default_alignment:
        raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")

    return alignment_provided


def update_alignment(alignment_provided:int, default_alignment: int) -> int:
    """
    Selects the alignment to use, clamping to `default_alignment` where possible.

    * If `alignment_provided` is None, `default_alignment` is returned.
    * If `alignment_provided` does not exceed `default_alignment`, it is returned unchanged.
    * If `alignment_provided` exceeds `default_alignment` and is an exact multiple of it,
      `default_alignment` is returned.
    * Otherwise an exception is raised.

    :param alignment_provided: alignment preference specified. Can be None.
    :type alignment_provided: int
    :param default_alignment: alignment to use if `alignment_provided` is None
    :type default_alignment: int

    :return: alignment to use
    :rtype: int

    :raises Exception: if `alignment_provided` exceeds `default_alignment` and is not a multiple of it
    """
    if alignment_provided is None:
        return default_alignment

    exceeds_default = alignment_provided > default_alignment
    if not exceeds_default:
        return alignment_provided

    # Too large: only acceptable when it is a whole multiple of the default.
    if alignment_provided % default_alignment != 0:
        raise Exception(f"Alignment {alignment_provided} exceeds the maximum supported of {default_alignment}.")
    return default_alignment