#include <cuda.h>
#include <cuda_runtime.h>
#include <torch/csrc/utils/pybind.h>
#if !defined(USE_ROCM)
#include <cuda_profiler_api.h>
#else
#include <hip/hip_runtime_api.h>
#endif
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAGuard.h>
namespace torch {
namespace cuda {
namespace shared {
#ifdef USE_ROCM
namespace {
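// Stand-in for the CUDA profiler start/stop calls on ROCm builds; the
// profiler bindings below use it as a no-op that simply reports success.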
hipError_t hipReturnSuccess() {
  return hipSuccess;
}
} // namespace
#endif
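
// Binds a small subset of the CUDA runtime (cudart) API into the `_cudart`
// submodule of the given Python module.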
void initCudartBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  auto cudart = m.def_submodule("_cudart", "libcudart.so bindings");

  // By splitting the names of these objects into two literals we prevent the
  // HIP rewrite rules from changing these names when building with HIP.
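  // (Otherwise the hipify step would rewrite the "cuda"-prefixed names,
  // changing the Python-facing attributes on ROCm builds.)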
#if !defined(USE_ROCM)
  py::enum_<cudaOutputMode_t>(
      cudart,
      "cuda"
      "OutputMode")
      .value("KeyValuePair", cudaKeyValuePair)
      .value("CSV", cudaCSV);
#endif
  py::enum_<cudaError_t>(
      cudart,
      "cuda"
      "Error")
      .value("success", cudaSuccess);

  cudart.def(
      "cuda"
      "GetErrorString",
      cudaGetErrorString);
  cudart.def(
      "cuda"
      "ProfilerStart",
#ifdef USE_ROCM
      hipReturnSuccess
#else
      cudaProfilerStart
#endif
  );
  cudart.def(
      "cuda"
      "ProfilerStop",
#ifdef USE_ROCM
      hipReturnSuccess
#else
      cudaProfilerStop
#endif
  );
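  // Pointers and stream handles cross the Python boundary as integer
  // addresses (uintptr_t) and are cast back to the corresponding CUDA types
  // before calling into the runtime.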
  cudart.def(
      "cuda"
      "HostRegister",
      [](uintptr_t ptr, size_t size, unsigned int flags) -> cudaError_t {
        return cudaHostRegister((void*)ptr, size, flags);
      });
  cudart.def(
      "cuda"
      "HostUnregister",
      [](uintptr_t ptr) -> cudaError_t {
        return cudaHostUnregister((void*)ptr);
      });
  cudart.def(
      "cuda"
      "StreamCreate",
      [](uintptr_t ptr) -> cudaError_t {
        return cudaStreamCreate((cudaStream_t*)ptr);
      });
  cudart.def(
      "cuda"
      "StreamDestroy",
      [](uintptr_t ptr) -> cudaError_t {
        return cudaStreamDestroy((cudaStream_t)ptr);
      });
#if !defined(USE_ROCM)
  cudart.def(
      "cuda"
      "ProfilerInitialize",
      cudaProfilerInitialize);
#endif
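  // Returns the free and total device memory, in bytes, for the given device
  // index.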
  cudart.def(
      "cuda"
      "MemGetInfo",
      [](int device) -> std::pair<size_t, size_t> {
        c10::cuda::CUDAGuard guard(device);
        size_t device_free = 0;
        size_t device_total = 0;
        // Check the runtime call so failures surface as exceptions instead of
        // silently returning zeros (c10/cuda/CUDAException.h is included for
        // this macro).
        C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total));
        return {device_free, device_total};
      });
}
} // namespace shared
} // namespace cuda
} // namespace torch