# mypy: allow-untyped-defs
import torch
from ..common import DeviceOpOverrides, register_device_op_overrides
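# Each override below returns a snippet of Python or C++ source as a string;
# Inductor's wrapper codegen splices these snippets into the code it generates
# for "cuda" devices (CUDA, and ROCm via HIP). Nothing in this class launches
# a kernel directly.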
class CUDADeviceOpOverrides(DeviceOpOverrides):
    def import_get_raw_stream_as(self, name):
        return f"from torch._C import _cuda_getCurrentRawStream as {name}"

    def set_device(self, device_idx):
        return f"torch.cuda.set_device({device_idx})"

    def synchronize(self):
        return "torch.cuda.synchronize()"

    def device_guard(self, device_idx):
        return f"torch.cuda._DeviceGuard({device_idx})"
    def cpp_device_guard(self):
        return "at::cuda::CUDAGuard"

    def cpp_aoti_device_guard(self):
        return "AOTICudaGuard"

    def cpp_stream_guard(self):
        return "at::cuda::CUDAStreamGuard"

    def cpp_aoti_stream_guard(self):
        return "AOTICudaStreamGuard"

    def cpp_getStreamFromExternal(self):
        return "at::cuda::getStreamFromExternal"
    def kernel_header(self):
        source_codes = """
        #include <c10/cuda/CUDAGuard.h>
        #include <c10/cuda/CUDAStream.h>
        #include <ATen/cuda/EmptyTensor.h>
        """
        return source_codes
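    # kernel_header() is consumed by the C++ wrapper codegen, which is expected
    # to place these includes near the top of the generated source so that the
    # CUDA guard/stream utilities referenced there are declared.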
    def kernel_driver(self):
        source_codes = """
            #define CUDA_DRIVER_CHECK(EXPR) \\
            do { \\
                CUresult code = EXPR; \\
                const char *msg; \\
                CUresult code_get_error = cuGetErrorString(code, &msg); \\
                if (code_get_error != CUDA_SUCCESS) { \\
                    throw std::runtime_error( \\
                        std::string("CUDA driver error: ") + \\
                        std::string("invalid error code!")); \\
                } \\
                if (code != CUDA_SUCCESS) { \\
                    throw std::runtime_error( \\
                        std::string("CUDA driver error: ") + \\
                        std::string(msg)); \\
                } \\
            } while (0);

            namespace {

            struct Grid {
                Grid(uint32_t x, uint32_t y, uint32_t z)
                  : grid_x(x), grid_y(y), grid_z(z) {}
                uint32_t grid_x;
                uint32_t grid_y;
                uint32_t grid_z;

                bool is_non_zero() {
                    return grid_x > 0 && grid_y > 0 && grid_z > 0;
                }
            };

            }  // anonymous namespace

            static inline CUfunction loadKernel(
                    std::string filePath,
                    const std::string &funcName,
                    uint32_t sharedMemBytes,
                    const std::optional<std::string> &cubinDir = std::nullopt) {
                if (cubinDir) {
                    std::filesystem::path p1{*cubinDir};
                    std::filesystem::path p2{filePath};
                    filePath = (p1 / p2.filename()).string();
                }

                CUmodule mod;
                CUfunction func;
                CUDA_DRIVER_CHECK(cuModuleLoad(&mod, filePath.c_str()));
                CUDA_DRIVER_CHECK(cuModuleGetFunction(&func, mod, funcName.c_str()));
                if (sharedMemBytes > 0) {
                    CUDA_DRIVER_CHECK(cuFuncSetAttribute(
                        func,
                        CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
                        sharedMemBytes
                    ))
                }
                return func;
            }

            static inline void launchKernel(
                    CUfunction func,
                    uint32_t gridX,
                    uint32_t gridY,
                    uint32_t gridZ,
                    uint32_t numWarps,
                    uint32_t sharedMemBytes,
                    void* args[],
                    cudaStream_t stream) {
                CUDA_DRIVER_CHECK(cuLaunchKernel(
                    func, gridX, gridY, gridZ, 32*numWarps, 1, 1, sharedMemBytes, stream, args, nullptr
                ));
            }
        """
        if torch.version.hip is not None:
            # Adjust the warp size to the wavefront size supported by the AMD GPU.
            prop = torch.cuda.get_device_properties(torch.cuda.current_device())
            source_codes = source_codes.replace(
                "32*numWarps", str(prop.warp_size) + "*numWarps"
            )
        return source_codes
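    # Sketch of how generated C++ wrapper code is expected to use the helpers
    # above (kernel name, cubin path, and launch parameters are illustrative,
    # not taken from this file):
    #
    #     static CUfunction triton_poi_fused_add_0 = nullptr;
    #     if (triton_poi_fused_add_0 == nullptr) {
    #         triton_poi_fused_add_0 = loadKernel(
    #             "triton_poi_fused_add_0.cubin", "triton_poi_fused_add_0", 0);
    #     }
    #     CUdeviceptr arg0 = ...;
    #     void* args[] = {&arg0};
    #     Grid grid(1024, 1, 1);
    #     if (grid.is_non_zero()) {
    #         launchKernel(triton_poi_fused_add_0, grid.grid_x, grid.grid_y,
    #                      grid.grid_z, /*numWarps=*/4, /*sharedMemBytes=*/0,
    #                      args, stream);
    #     }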
    def tma_descriptor_helpers(self):
        if torch.version.hip is not None:
            raise RuntimeError("Host-side TMA descriptors not supported on HIP.")

        # Helper functions for initializing 1D and 2D TMA descriptors in C++.
        # Borrowed from the Triton code here:
        # https://github.com/triton-lang/triton/blob/6af4f88591c85de079d8a36a4d7dba67918e2b39/third_party/nvidia/backend/driver.c#L283
return """
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000
[[maybe_unused]] static void init1DTMADescriptor(
CUtensorMap* m,
void* globalAddress,
uint64_t dim,
uint32_t blockDim,
uint32_t elementSize) {
uint64_t dims[1] = {dim};
uint64_t globalStrides[1] = {dim * elementSize};
uint32_t tensorDims[1] = {blockDim};
uint32_t elementStrides[1] = {1};
CUtensorMapDataType type;
switch (elementSize) {
case 1:
type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
case 2:
type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 4:
type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
break;
default:
throw std::runtime_error("elementSize must be 1, 2, or 4");
}
if (elementSize * blockDim < 32) {
throw std::runtime_error("block size too small");
}
int rank = 1;
CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
m, type, rank, globalAddress, dims,
globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
CU_TENSOR_MAP_SWIZZLE_NONE, CU_TENSOR_MAP_L2_PROMOTION_NONE,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
[[maybe_unused]] static void init2DTMADescriptor(
CUtensorMap* m,
void* globalAddress,
uint64_t dim1,
uint64_t dim0,
uint32_t blockDim1,
uint32_t blockDim0,
uint32_t elementSize) {
uint64_t dims[2] = {dim0, dim1};
uint32_t tensorDims[2] = {blockDim0, blockDim1};
uint64_t globalStrides[2] = {dims[0] * elementSize,
dims[0] * dims[1] * elementSize};
uint32_t elementStrides[2] = {1, 1};
CUtensorMapDataType type;
switch (elementSize) {
case 1:
type = CU_TENSOR_MAP_DATA_TYPE_UINT8;
break;
case 2:
type = CU_TENSOR_MAP_DATA_TYPE_UINT16;
break;
case 4:
type = CU_TENSOR_MAP_DATA_TYPE_UINT32;
break;
default:
throw std::runtime_error("elementSize must be 1, 2, or 4");
}
int rank = 2;
CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
uint32_t contigDimSizeInByte = elementSize * tensorDims[0];
if (contigDimSizeInByte >= 128) {
swizzle = CU_TENSOR_MAP_SWIZZLE_128B;
} else if (contigDimSizeInByte >= 64) {
swizzle = CU_TENSOR_MAP_SWIZZLE_64B;
} else if (contigDimSizeInByte >= 32) {
swizzle = CU_TENSOR_MAP_SWIZZLE_32B;
} else {
throw std::runtime_error("block size too small");
}
if (contigDimSizeInByte > 128) {
tensorDims[0] = 128 / elementSize;
}
CUDA_DRIVER_CHECK(cuTensorMapEncodeTiled(
m, type, rank, globalAddress, dims,
globalStrides, tensorDims, elementStrides, CU_TENSOR_MAP_INTERLEAVE_NONE,
swizzle, CU_TENSOR_MAP_L2_PROMOTION_L2_128B,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
#endif
"""
    def abi_compatible_header(self):
        return "#include <torch/csrc/inductor/aoti_runtime/utils_cuda.h>"

    def cpp_stream_type(self):
        return "cudaStream_t"

    def aoti_get_stream(self):
        return "aoti_torch_get_current_cuda_stream"

    def cpp_kernel_type(self):
        return "CUfunction"

    def cpp_device_ptr(self):
        return "CUdeviceptr"


register_device_op_overrides("cuda", CUDADeviceOpOverrides())
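# The registration above keys these overrides by device type, so wrapper
# codegen for "cuda" devices (including ROCm, which reuses the "cuda" device
# key) picks up this class through the registry in ..common.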