File: pycudnn.cpp

package info (click to toggle)
nvidia-cudnn-frontend 1.8.0+ds-1
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 4,376 kB
  • sloc: cpp: 58,463; python: 4,138; ansic: 1,407; makefile: 4
file content (88 lines) | stat: -rw-r--r-- 3,223 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#include <utility>

#include "pybind11/pybind11.h"
#include "pybind11/cast.h"
#include "pybind11/stl.h"

#include "cudnn_frontend.h"

namespace py = pybind11;
using namespace pybind11::literals;

namespace cudnn_frontend {

// Opaque library handle handed over from Python as an integer through the
// `_set_dlhandle_cudnn` binding (stored by set_dlhandle_cudnn). Starts out as nullptr.
// NOTE(review): nothing in this file reads this handle — presumably the frontend's
// dynamic-loading code uses it as a dlopen() handle for cuDNN; confirm there how a
// nullptr value is interpreted.
void *cudnn_dlhandle = nullptr;

namespace python_bindings {

// Throw the C++ exception that corresponds to a cudnn-frontend error code.
//
// pybind11 translates the thrown C++ exception into a Python exception:
//   std::invalid_argument                          -> ValueError
//   std::runtime_error                             -> RuntimeError
//   cudnn_frontend::cudnnGraphNotSupportedException -> cudnnGraphNotSupportedError
//     (registered on the module in PYBIND11_MODULE)
//
// @param cond        Nothing happens unless this is true.
// @param error_code  Selects the exception type; error_code_t::OK never throws.
// @param error_msg   Message carried by the thrown exception.
// @throws std::invalid_argument, std::runtime_error or
//         cudnn_frontend::cudnnGraphNotSupportedException when cond is true and
//         error_code is not OK.
void
throw_if(bool const cond, cudnn_frontend::error_code_t const error_code, std::string const &error_msg) {
    if (!cond) return;

    switch (error_code) {
        case cudnn_frontend::error_code_t::OK:
            return;

        // Surfaced in Python as ValueError.
        case cudnn_frontend::error_code_t::ATTRIBUTE_NOT_SET:
        case cudnn_frontend::error_code_t::SHAPE_DEDUCTION_FAILED:
        case cudnn_frontend::error_code_t::INVALID_TENSOR_NAME:
        case cudnn_frontend::error_code_t::INVALID_VARIANT_PACK:
            throw std::invalid_argument(error_msg);

        // Surfaced in Python as the module's cudnnGraphNotSupportedError.
        case cudnn_frontend::error_code_t::GRAPH_EXECUTION_PLAN_CREATION_FAILED:
        case cudnn_frontend::error_code_t::HEURISTIC_QUERY_FAILED:
        case cudnn_frontend::error_code_t::UNSUPPORTED_GRAPH_FORMAT:
        case cudnn_frontend::error_code_t::GRAPH_NOT_SUPPORTED:
            throw cudnn_frontend::cudnnGraphNotSupportedException(error_msg.c_str());

        // Surfaced in Python as RuntimeError.
        case cudnn_frontend::error_code_t::GRAPH_EXECUTION_FAILED:
        case cudnn_frontend::error_code_t::CUDNN_BACKEND_API_FAILED:
        case cudnn_frontend::error_code_t::CUDA_API_FAILED:
        case cudnn_frontend::error_code_t::INVALID_CUDA_DEVICE:
        case cudnn_frontend::error_code_t::HANDLE_ERROR:
        case cudnn_frontend::error_code_t::INVALID_VALUE:
            throw std::runtime_error(error_msg);
    }

    // Only reachable for an error_code that no case label above handles (an enumerator
    // added later, or an out-of-range value). Never let a reported failure pass silently.
    // A `default:` label is deliberately avoided so -Wswitch still warns when a new
    // enumerator is missing from the switch.
    throw std::runtime_error(error_msg);
}

// pybinds for pygraph class
void
init_pygraph_submodule(py::module_ &);

// pybinds for kernel_cache class
void
create_kernel_cache_submodule(py::module_ &);

// pybinds for all properties and helpers
void
init_properties(py::module_ &);

// Store a library handle that the Python side passes as a pointer-sized integer.
// The integer is reinterpreted as a pointer and written to cudnn_dlhandle unchanged;
// no validation (e.g. a null check) happens here.
void set_dlhandle_cudnn(std::intptr_t const dlhandle) {
    void *const handle_ptr = reinterpret_cast<void *>(dlhandle);
    cudnn_dlhandle         = handle_ptr;
}

// Entry point of the `_compiled_module` Python extension. In order, it registers:
//   1. the backend version queries,
//   2. the property/helper bindings and the pygraph bindings,
//   3. the private `_set_dlhandle_cudnn` hook, and
//   4. the Python exception type for cudnnGraphNotSupportedException.
PYBIND11_MODULE(_compiled_module, m) {
    // Bind detail::get_backend_version / detail::get_backend_version_string.
    m.def("backend_version", &detail::get_backend_version);
    m.def("backend_version_string", &detail::get_backend_version_string);

    // NOTE(review): create_kernel_cache_submodule is declared in this file but not called
    // here — presumably init_properties or another translation unit registers the kernel
    // cache bindings; confirm before relying on it.
    init_properties(m);
    init_pygraph_submodule(m);

    // Private hook (leading underscore): the Python package passes a library handle as an
    // integer, which set_dlhandle_cudnn stores in cudnn_dlhandle.
    m.def("_set_dlhandle_cudnn", &set_dlhandle_cudnn);

    // Expose cudnnGraphNotSupportedException as the module attribute
    // cudnnGraphNotSupportedError; throw_if raises this C++ exception for the
    // graph-not-supported family of error codes.
    py::register_exception<cudnnGraphNotSupportedException>(m, "cudnnGraphNotSupportedError");
}

}  // namespace python_bindings
}  // namespace cudnn_frontend