File: device_op_overrides.py

Package: pytorch-cuda 2.6.0+dfsg-7 (Debian contrib; suites: forky, sid, trixie)

# mypy: allow-untyped-defs
import torch

from ..common import DeviceOpOverrides, register_device_op_overrides


class CUDADeviceOpOverrides(DeviceOpOverrides):
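    """CUDA-specific code generation hooks for Inductor.

    Each method returns a fragment of Python or C++ source text that the
    Inductor wrapper codegen splices into generated code: device and stream
    handling, kernel loading and launching, and host-side TMA descriptor
    setup on supported CUDA versions.
    """
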
    def import_get_raw_stream_as(self, name):
        return f"from torch._C import _cuda_getCurrentRawStream as {name}"

    def set_device(self, device_idx):
        return f"torch.cuda.set_device({device_idx})"

    def synchronize(self):
        return "torch.cuda.synchronize()"

    def device_guard(self, device_idx):
        return f"torch.cuda._DeviceGuard({device_idx})"

    def cpp_device_guard(self):
        return "at::cuda::CUDAGuard"

    def cpp_aoti_device_guard(self):
        return "AOTICudaGuard"

    def cpp_stream_guard(self):
        return "at::cuda::CUDAStreamGuard"

    def cpp_aoti_stream_guard(self):
        return "AOTICudaStreamGuard"

    def cpp_getStreamFromExternal(self):
        return "at::cuda::getStreamFromExternal"

    def kernel_header(self):
        source_codes = """
        #include <c10/cuda/CUDAGuard.h>
        #include <c10/cuda/CUDAStream.h>
        #include <ATen/cuda/EmptyTensor.h>
        """
        return source_codes

    def kernel_driver(self):
        source_codes = """
            #define CUDA_DRIVER_CHECK(EXPR)                    \\
            do {                                               \\
                CUresult code = EXPR;                          \\
                const char *msg;                               \\
                CUresult code_get_error = cuGetErrorString(code, &msg); \\
                if (code_get_error != CUDA_SUCCESS) {          \\
                    throw std::runtime_error(                  \\
                        std::string("CUDA driver error: ") +   \\
                        std::string("invalid error code!"));   \\
                }                                              \\
                if (code != CUDA_SUCCESS) {                    \\
                    throw std::runtime_error(                  \\
                        std::string("CUDA driver error: ") +   \\
                        std::string(msg));                     \\
                }                                              \\
            } while (0)

            namespace {

            struct Grid {
                Grid(uint32_t x, uint32_t y, uint32_t z)
                  : grid_x(x), grid_y(y), grid_z(z) {}
                uint32_t grid_x;
                uint32_t grid_y;
                uint32_t grid_z;

                bool is_non_zero() {
                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
                }
            };

            }  // anonymous namespace

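            // Load a compiled kernel module from disk, resolve funcName, and
            // optionally raise its dynamic shared-memory limit. If cubinDir is
            // set, the file is looked up there by its base name instead of at
            // filePath.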
            static inline CUfunction loadKernel(
                    std::string filePath,
                    const std::string &funcName,
                    uint32_t sharedMemBytes,
                    const std::optional<std::string> &cubinDir = std::nullopt) {
                if (cubinDir) {
                    std::filesystem::path p1{*cubinDir};
                    std::filesystem::path p2{filePath};
                    filePath = (p1 / p2.filename()).string();
                }

                CUmodule mod;
                CUfunction func;
                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
                if (sharedMemBytes > 0) {
                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
                        func,
                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                        sharedMemBytes
                    ));
                }
                return func;
            }

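            // Launch the kernel on the given stream with a 1-D thread block
            // of numWarps warps.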
            static inline void launchKernel(
                    CUfunction func,
                    uint32_t gridX,
                    uint32_t gridY,
                    uint32_t gridZ,
                    uint32_t numWarps,
                    uint32_t sharedMemBytes,
                    void* args[],
                    cudaStream_t stream) {
                CUDA_DRIVER_CHECK(cuLaunchKernel(
                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
                ));
            }
        """
        if torch.version.hip is not None:
            # Adjust the warp size to the wavefront size supported by the AMD GPU.
            prop = torch.cuda.get_device_properties(torch.cuda.current_device())
            source_codes = source_codes.replace(
                "32*numWarps", str(prop.warp_size) + "*numWarps"
            )
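            # For example, on most AMD GPUs prop.warp_size is 64, so the
            # generated launcher computes its block width as 64*numWarps.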
        return source_codes

    def tma_descriptor_helpers(self):
        if torch.version.hip is not None:
            raise RuntimeError("Host-side TMA descriptors not supported on HIP.")

        # Helper functions for initializing 1D and 2D TMA descriptors in C++, borrowed from the Triton code here:
        # https://github.com/triton-lang/triton/blob/6af4f88591c85de079d8a36a4d7dba67918e2b39/third_party/nvidia/backend/driver.c#L283
        return """
            #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
            [[maybe_unused]] static void init1DTMADescriptor(
                    CUtensorMap* m,
                    void* globalAddress,
                    uint64_t dim,
                    uint32_t blockDim,
                    uint32_t elementSize) {
                uint64_t dims[1] = {dim};
                uint64_t globalStrides[1] = {dim * elementSize};
                uint32_t tensorDims[1] = {blockDim};
                uint32_t elementStrides[1] = {1};

                CUtensorMapDataType type;
                switch (elementSize) {
                case 1:
                    type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
                    break;
                case 2:
                    type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
                    break;
                case 4:
                    type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
                    break;
                default:
                    throw std::runtime_error("elementSize must be 1, 2, or 4");
                }

                if (elementSize * blockDim < 32) {
                    throw std::runtime_error("block size too small");
                }

                int rank = 1;

                CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
                    m, type, rank, globalAddress, dims,
                    globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
                    CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
                    CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
            }

            [[maybe_unused]] static void init2DTMADescriptor(
                    CUtensorMap* m,
                    void* globalAddress,
                    uint64_t dim1,
                    uint64_t dim0,
                    uint32_t blockDim1,
                    uint32_t blockDim0,
                    uint32_t elementSize) {
                uint64_t dims[2] = {dim0, dim1};
                uint32_t tensorDims[2] = {blockDim0, blockDim1};
                uint64_t globalStrides[2] = {dims[0] * elementSize,
                                             dims[0] * dims[1] * elementSize};
                uint32_t elementStrides[2] = {1, 1};

                CUtensorMapDataType type;
                switch (elementSize) {
                case 1:
                    type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
                    break;
                case 2:
                    type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
                    break;
                case 4:
                    type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
                    break;
                default:
                    throw std::runtime_error("elementSize must be 1, 2, or 4");
                }

                int rank = 2;

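                // Pick the swizzle mode from the byte width of the contiguous
                // (innermost) block dimension; widths under 32 bytes are rejected.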
                CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
                uint32_t contigDimSizeInByte = elementSize * tensorDims[0];
                if (contigDimSizeInByte >= 128) {
                    swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
                } else if (contigDimSizeInByte >= 64) {
                    swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
                } else if (contigDimSizeInByte >= 32) {
                    swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
                } else {
                    throw std::runtime_error("block size too small");
                }

                if (contigDimSizeInByte > 128) {
                    tensorDims[0] = 128 / elementSize;
                }

                CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
                    m, type, rank, globalAddress, dims,
                    globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
                    swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
                    CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
            }
            #endif
        """

    def abi_compatible_header(self):
        return "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"

    def cpp_stream_type(self):
        return "cudaStream_t"

    def aoti_get_stream(self):
        return "aoti_torch_get_current_cuda_stream"

    def cpp_kernel_type(self):
        return "CUfunction"

    def cpp_device_ptr(self):
        return "CUdeviceptr"


register_device_op_overrides("cuda", CUDADeviceOpOverrides())
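
# A minimal usage sketch (illustrative, not part of this module): Inductor
# looks the overrides up by device string and calls them to emit source
# fragments. `get_device_op_overrides` is assumed to be the retrieval
# counterpart of `register_device_op_overrides` in
# torch._inductor.codegen.common.
#
#     from torch._inductor.codegen.common import get_device_op_overrides
#
#     overrides = get_device_op_overrides("cuda")
#     overrides.set_device(0)
#     # -> "torch.cuda.set_device(0)"
#     overrides.import_get_raw_stream_as("get_raw_stream")
#     # -> "from torch._C import _cuda_getCurrentRawStream as get_raw_stream"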