1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
|
#include <utility>
#include "pybind11/pybind11.h"
#include "pybind11/cast.h"
#include "pybind11/stl.h"
#include "cudnn_frontend.h"
namespace py = pybind11;
using namespace pybind11::literals;
namespace cudnn_frontend {
// Namespace-scope handle for the cuDNN library, kept as an opaque pointer and
// initialized to nullptr. python_bindings::set_dlhandle_cudnn (exposed to Python
// as `_set_dlhandle_cudnn`) overwrites it with an address supplied from Python.
// NOTE(review): presumably a dlopen()-style handle consumed by the frontend's
// dynamic-loading code outside this file — confirm.
void *cudnn_dlhandle = nullptr;
namespace python_bindings {
// Raise the C++ exception that corresponds to a C++ frontend error code.
// pybind11 translates the thrown exception into a Python exception:
//   std::invalid_argument             -> ValueError
//   std::runtime_error                -> RuntimeError
//   cudnnGraphNotSupportedException   -> cudnnGraphNotSupportedError
//                                        (registered by PYBIND11_MODULE in this file)
//
// Parameters:
//   cond       - when false, nothing happens; when true, an exception is raised
//                for every error code except OK.
//   error_code - selects the exception type.
//   error_msg  - message carried by the exception.
// Throws: as described above. An error code not handled by the switch still
// raises std::runtime_error instead of being silently ignored.
void
throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const &error_msg) {
    if (!cond) return;

    // No `default:` on purpose, so -Wswitch keeps flagging newly added enumerators.
    switch (error_code) {
        case cudnn_frontend::error_code_t::OK:
            return;

        // Surfaced in Python as ValueError.
        case cudnn_frontend::error_code_t::ATTRIBUTE_NOT_SET:
        case cudnn_frontend::error_code_t::SHAPE_DEDUCTION_FAILED:
        case cudnn_frontend::error_code_t::INVALID_TENSOR_NAME:
        case cudnn_frontend::error_code_t::INVALID_VARIANT_PACK:
            throw std::invalid_argument(error_msg);

        // Surfaced in Python as cudnnGraphNotSupportedError.
        case cudnn_frontend::error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED:
        case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED:
        case cudnn_frontend::error_code_t::UNSUPPORTED_GRAPH_FORMAT:
        case cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED:
            throw cudnn_frontend::cudnnGraphNotSupportedException(error_msg.c_str());

        // Surfaced in Python as RuntimeError.
        case cudnn_frontend::error_code_t::GRAPH_EXECUTION_FAILED:
        case cudnn_frontend::error_code_t::CUDNN_BACKEND_API_FAILED:
        case cudnn_frontend::error_code_t::CUDA_API_FAILED:
        case cudnn_frontend::error_code_t::INVALID_CUDA_DEVICE:
        case cudnn_frontend::error_code_t::HANDLE_ERROR:
        case cudnn_frontend::error_code_t::INVALID_VALUE:
            throw std::runtime_error(error_msg);
    }

    // Reached only for a value not handled above (e.g. an enumerator added in a
    // newer frontend version, or an out-of-range cast). The caller reported a
    // failure, so do not swallow it.
    throw std::runtime_error(error_msg);
}
// pybinds for pygraph class
void
init_pygraph_submodule(py::module_ &);
// pybinds for kernel_cache class
void
create_kernel_cache_submodule(py::module_ &);
// pybinds for all properties and helpers
void
init_properties(py::module_ &);
// Store a library handle handed over from Python as an integer address.
// The integer is reinterpreted as a pointer and written to
// cudnn_frontend::cudnn_dlhandle, replacing any previous value.
// Exposed to Python as `_set_dlhandle_cudnn`.
void
set_dlhandle_cudnn(std::intptr_t dlhandle) {
    void *const handle_ptr = reinterpret_cast<void *>(dlhandle);
    cudnn_dlhandle         = handle_ptr;
}
// Initializer for the `_compiled_module` Python extension.
// Registration order is intentionally unchanged: version queries, then the
// bindings defined outside this file, then the dlhandle setter, and finally
// the exception type.
PYBIND11_MODULE(_compiled_module, mod) {
    // cuDNN backend version queries, forwarded from detail::.
    mod.def("backend_version", &detail::get_backend_version);
    mod.def("backend_version_string", &detail::get_backend_version_string);

    // Bindings implemented outside this file: properties/helpers first, then pygraph.
    init_properties(mod);
    init_pygraph_submodule(mod);

    // Lets Python supply the library handle stored in cudnn_dlhandle.
    mod.def("_set_dlhandle_cudnn", &set_dlhandle_cudnn);

    // Give cudnnGraphNotSupportedException (raised by throw_if) its own Python type.
    py::register_exception<cudnnGraphNotSupportedException>(mod, "cudnnGraphNotSupportedError");
}
} // namespace python_bindings
} // namespace cudnn_frontend
|