1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
|
// Without this guard, the clang-tidy job complains that it cannot find
// cudnn.h. This file is only compiled when this condition holds, so the
// guard is safe.
#if defined(USE_CUDNN) || defined(USE_ROCM)
#include <torch/csrc/utils/pybind.h>
#include <array>
#include <tuple>
namespace {
// A library version expressed as a (major, minor, patch) triple.
using version_tuple = std::tuple<size_t, size_t, size_t>;
} // namespace
#ifdef USE_CUDNN
#include <cudnn.h>
namespace {
// Version of the cuDNN headers this file was compiled against, as
// (major, minor, patch).
version_tuple getCompileVersion() {
  return version_tuple(CUDNN_MAJOR, CUDNN_MINOR, CUDNN_PATCHLEVEL);
}

// Version of the cuDNN library loaded at runtime, as (major, minor, patch).
//
// Each component is queried individually with cudnnGetProperty instead of
// being decoded from the packed cudnnGetVersion() integer. The packed layout
// is not stable across releases (MAJOR*1000 + MINOR*100 + PATCH up to cuDNN 8,
// MAJOR*10000 + MINOR*100 + PATCH from cuDNN 9). Taking `version % 10` as the
// patch level also dropped its tens digit.
// With a statically linked cuDNN the runtime library is the compiled one, so
// the header constants are returned directly.
version_tuple getRuntimeVersion() {
#ifndef USE_STATIC_CUDNN
  // Zero-initialized so that a failed query reports 0 instead of garbage.
  int major = 0;
  int minor = 0;
  int patch = 0;
  cudnnGetProperty(MAJOR_VERSION, &major);
  cudnnGetProperty(MINOR_VERSION, &minor);
  cudnnGetProperty(PATCH_LEVEL, &patch);
  return version_tuple(
      static_cast<size_t>(major),
      static_cast<size_t>(minor),
      static_cast<size_t>(patch));
#else
  return getCompileVersion();
#endif
}

// Raw packed version integer: cudnnGetVersion() for a dynamically loaded
// cuDNN, or the compile-time CUDNN_VERSION for a static build.
size_t getVersionInt() {
#ifndef USE_STATIC_CUDNN
  return cudnnGetVersion();
#else
  return CUDNN_VERSION;
#endif
}
} // namespace
#elif defined(USE_ROCM)
#include <miopen/miopen.h>
#include <miopen/version.h>
namespace {
version_tuple getCompileVersion() {
return version_tuple(
MIOPEN_VERSION_MAJOR, MIOPEN_VERSION_MINOR, MIOPEN_VERSION_PATCH);
}
version_tuple getRuntimeVersion() {
// MIOpen doesn't include runtime version info before 2.3.0
#if (MIOPEN_VERSION_MAJOR > 2) || \
(MIOPEN_VERSION_MAJOR == 2 && MIOPEN_VERSION_MINOR > 2)
size_t major, minor, patch;
miopenGetVersion(&major, &minor, &patch);
return version_tuple(major, minor, patch);
#else
return getCompileVersion();
#endif
}
size_t getVersionInt() {
// miopen version is MAJOR*1000000 + MINOR*1000 + PATCH
size_t major, minor, patch;
std::tie(major, minor, patch) = getRuntimeVersion();
return major * 1000000 + minor * 1000 + patch;
}
} // namespace
#endif
namespace torch {
namespace cuda {
namespace shared {

// Registers the `_cudnn` submodule on `module`. The submodule exposes:
//   - the `RNNMode` enum (rnn_relu, rnn_tanh, lstm, gru);
//   - an `is_cuda` flag (true for cuDNN builds, false otherwise);
//   - the getRuntimeVersion / getCompileVersion / getVersionInt helpers
//     defined above for the active backend.
void initCudnnBindings(PyObject* module) {
  auto parent = py::handle(module).cast<py::module>();
  auto cudnn_module = parent.def_submodule("_cudnn", "libcudnn.so bindings");

  py::enum_<cudnnRNNMode_t> rnn_mode(cudnn_module, "RNNMode");
  rnn_mode.value("rnn_relu", CUDNN_RNN_RELU);
  rnn_mode.value("rnn_tanh", CUDNN_RNN_TANH);
  rnn_mode.value("lstm", CUDNN_LSTM);
  rnn_mode.value("gru", CUDNN_GRU);

  // Python's runtime version check uses this flag to tell cuDNN apart from
  // MIOpen.
#ifdef USE_CUDNN
  constexpr bool kIsCuda = true;
#else
  constexpr bool kIsCuda = false;
#endif
  cudnn_module.attr("is_cuda") = kIsCuda;

  cudnn_module.def("getRuntimeVersion", getRuntimeVersion);
  cudnn_module.def("getCompileVersion", getCompileVersion);
  cudnn_module.def("getVersionInt", getVersionInt);
}

} // namespace shared
} // namespace cuda
} // namespace torch
#endif
|