1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
|
#include <c10/util/error.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/utils/pybind.h>
#if defined(USE_CUFILE)
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <cufile.h>
namespace {
// Builds a human-readable message describing a failed cuFile call.
//
// Two kinds of status are accepted:
//  * An integral value returned by cuFileRead/cuFileWrite (ssize_t): -1 means
//    a filesystem error (described by errno); any other negative value is a
//    negated CUfileOpError code.
//  * A CUfileError_t returned by the buffer/handle registration APIs: its
//    `err` field is described as above, and when IS_CUDA_ERR holds the CUDA
//    error string for `cu_err` is appended after a ".".
template <class T>
std::string cuGDSFileGetErrorString(T status) {
  if constexpr (std::is_integral_v<T>) {
    const T code = std::abs(status);
    if (IS_CUFILE_ERR(code)) {
      return std::string(CUFILE_ERRSTR(code));
    }
    return std::string(c10::utils::str_error(errno));
  } else {
    std::string message =
        cuGDSFileGetErrorString(static_cast<int>(status.err));
    if (IS_CUDA_ERR(status)) {
      message.append(".").append(
          cudaGetErrorString(static_cast<cudaError_t>(status.cu_err)));
    }
    return message;
  }
}
} // namespace
// Reads `storage.nbytes()` bytes from the cuFile handle at file position
// `offset` directly into the memory backing `storage`.
//
// `handle` is the int64_t-encoded CUfileHandle_t produced by
// gds_register_handle. Raises (via TORCH_CHECK) if cuFileRead reports an error
// or transfers fewer bytes than the storage holds (for example when the file
// ends before `offset + nbytes`). Without that check a truncated file would
// silently leave the storage only partially filled.
void gds_load_storage(
    int64_t handle,
    const at::Storage& storage,
    off_t offset) {
  const size_t nbytes = storage.nbytes();
  if (nbytes == 0) {
    // Nothing to transfer; an empty storage's data pointer may be null, so do
    // not hand it to cuFile at all.
    return;
  }
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
  // Make the storage's device current for the duration of the read
  // (presumably required by cuFile for device buffers).
  c10::cuda::CUDAGuard gpuGuard(storage.device());
  // The read writes into the storage, so mutable access is required.
  void* dataPtr = storage.mutable_data();
  // Read the binary file
  const ssize_t ret = cuFileRead(cf_handle, dataPtr, nbytes, offset, 0);
  TORCH_CHECK(ret >= 0, "cuFileRead failed: ", cuGDSFileGetErrorString(ret));
  TORCH_CHECK(
      static_cast<size_t>(ret) == nbytes,
      "cuFileRead short read: expected ",
      nbytes,
      " bytes at offset ",
      offset,
      ", got ",
      ret);
}
// Writes the `storage.nbytes()` bytes backing `storage` to the cuFile handle
// at file position `offset`.
//
// `handle` is the int64_t-encoded CUfileHandle_t produced by
// gds_register_handle. Raises (via TORCH_CHECK) if cuFileWrite reports an
// error or writes fewer bytes than the storage holds.
void gds_save_storage(
    int64_t handle,
    const at::Storage& storage,
    off_t offset) {
  const size_t nbytes = storage.nbytes();
  if (nbytes == 0) {
    // Nothing to transfer; an empty storage's data pointer may be null, so do
    // not hand it to cuFile at all.
    return;
  }
  // NOLINTNEXTLINE(performance-no-int-to-ptr)
  CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
  // Make the storage's device current for the duration of the write
  // (presumably required by cuFile for device buffers).
  c10::cuda::CUDAGuard gpuGuard(storage.device());
  // Saving only reads the storage. cuFileWrite takes `const void*`, so
  // read-only access suffices. This avoids mutable_data(), which can
  // materialize copy-on-write storage or trip mutable-data guards.
  const void* dataPtr = storage.data();
  // Write device memory contents to the file
  const ssize_t ret = cuFileWrite(cf_handle, dataPtr, nbytes, offset, 0);
  TORCH_CHECK(ret >= 0, "cuFileWrite failed: ", cuGDSFileGetErrorString(ret));
  TORCH_CHECK(
      static_cast<size_t>(ret) == nbytes,
      "cuFileWrite short write: expected ",
      nbytes,
      " bytes at offset ",
      offset,
      ", got ",
      ret);
}
// Registers the memory backing `storage` (all `storage.nbytes()` bytes, flags
// 0) with cuFile through cuFileBufRegister.
// Raises (via TORCH_CHECK) unless the call returns CU_FILE_SUCCESS.
void gds_register_buffer(const at::Storage& storage) {
  void* const base = storage.mutable_data();
  const size_t length = storage.nbytes();
  const CUfileError_t result = cuFileBufRegister(base, length, 0);
  TORCH_CHECK(
      result.err == CU_FILE_SUCCESS,
      "cuFileBufRegister failed: ",
      cuGDSFileGetErrorString(result));
}
// Undoes gds_register_buffer: deregisters the memory backing `storage` from
// cuFile through cuFileBufDeregister.
// Raises (via TORCH_CHECK) unless the call returns CU_FILE_SUCCESS.
void gds_deregister_buffer(const at::Storage& storage) {
  void* const base = storage.mutable_data();
  const CUfileError_t result = cuFileBufDeregister(base);
  TORCH_CHECK(
      result.err == CU_FILE_SUCCESS,
      "cuFileBufDeregister failed: ",
      cuGDSFileGetErrorString(result));
}
int64_t gds_register_handle(int fd) {
CUfileDescr_t cf_descr;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
CUfileHandle_t cf_handle;
memset((void*)&cf_descr, 0, sizeof(CUfileDescr_t));
cf_descr.handle.fd = fd;
cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD;
CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr);
if (status.err != CU_FILE_SUCCESS) {
TORCH_CHECK(
false,
"cuFileHandleRegister failed: ",
cuGDSFileGetErrorString(status));
}
// Returning cuFileHandle_t as int64_t
return reinterpret_cast<int64_t>(cf_handle);
}
void gds_deregister_handle(int64_t handle) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
CUfileHandle_t cf_handle = reinterpret_cast<CUfileHandle_t>(handle);
cuFileHandleDeregister(cf_handle);
}
#endif
namespace torch::cuda::shared {
// Attaches the private `_gds_*` cuFile helpers to the Python `module`.
// When the build does not define USE_CUFILE, nothing is registered.
void initGdsBindings(PyObject* module) {
  py::module m = py::handle(module).cast<py::module>();
#if defined(USE_CUFILE)
  m.def("_gds_register_handle", &gds_register_handle) // file handles
      .def("_gds_deregister_handle", &gds_deregister_handle)
      .def("_gds_register_buffer", &gds_register_buffer) // device buffers
      .def("_gds_deregister_buffer", &gds_deregister_buffer)
      .def("_gds_load_storage", &gds_load_storage) // data transfer
      .def("_gds_save_storage", &gds_save_storage);
#endif
}
} // namespace torch::cuda::shared
|