File: diagnostics.h

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (63 lines) | stat: -rw-r--r-- 1,650 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#pragma once
#include <torch/csrc/onnx/diagnostics/generated/rules.h>
#include <torch/csrc/utils/pybind.h>

#include <cstdint>
#include <string>
#include <unordered_map>

namespace torch::onnx::diagnostics {

/**
 * @brief Level of a diagnostic.
 * @details The levels are defined by the SARIF specification, and are not
 * modifiable. For alternative categories, please use Tag instead.
 * @todo Introduce Tag to C++ api.
 */
enum class Level : uint8_t {
  kNone,
  kNote,
  kWarning,
  kError,
};

/**
 * @brief Python attribute names of the diagnostic levels, indexed by the
 * numeric value of `Level`.
 * @details Entry `i` must be the name for the `Level` enumerator whose value
 * is `i`; `_PyLevel` looks the name up by that index without bounds checking.
 * Declared `inline constexpr` (C++17) so every translation unit that includes
 * this header shares one definition rather than getting its own internal copy.
 */
// NOLINTNEXTLINE(*array*)
inline constexpr const char* const kPyLevelNames[] = {
    "NONE",
    "NOTE",
    "WARNING",
    "ERROR",
};

// Keep the name table in lockstep with `Level` so indexing by enum value can
// never read past the end. `kError` is assumed to be the last enumerator;
// update this check if a new level is appended after it.
static_assert(
    sizeof(kPyLevelNames) / sizeof(kPyLevelNames[0]) ==
        static_cast<uint32_t>(Level::kError) + 1,
    "kPyLevelNames must have exactly one entry per Level enumerator");

// Wrappers around Python diagnostics.
// TODO: Move to .cpp file in following PR.

/// Import the Python-side diagnostics package
/// (`torch.onnx._internal.diagnostics`) and return it as a generic object.
inline py::object _PyDiagnostics() {
  py::object diagnostics_module =
      py::module::import("torch.onnx._internal.diagnostics");
  return diagnostics_module;
}

/// Resolve a C++ `Rule` to the matching attribute of the Python `rules`
/// namespace of the diagnostics package.
/// The attribute name is read from `kPyRuleNames` (presumably declared in the
/// generated rules header), indexed by the rule's numeric value; the index is
/// not bounds-checked here.
inline py::object _PyRule(Rule rule) {
  py::object rules_namespace = _PyDiagnostics().attr("rules");
  const auto rule_index = static_cast<uint32_t>(rule);
  return rules_namespace.attr(kPyRuleNames[rule_index]);
}

/// Resolve a C++ `Level` to the matching attribute of the Python `levels`
/// namespace of the diagnostics package.
/// The attribute name is read from `kPyLevelNames`, indexed by the level's
/// numeric value.
inline py::object _PyLevel(Level level) {
  py::object levels_namespace = _PyDiagnostics().attr("levels");
  const char* const level_name = kPyLevelNames[static_cast<uint32_t>(level)];
  return levels_namespace.attr(level_name);
}

inline void Diagnose(
    Rule rule,
    Level level,
    std::unordered_map<std::string, std::string> messageArgs = {}) {
  py::object py_rule = _PyRule(rule);
  py::object py_level = _PyLevel(level);

  // TODO: statically check that size of messageArgs matches with rule.
  py::object py_message =
      py_rule.attr("format_message")(**py::cast(messageArgs));

  // to use the `_a` literal for arguments
  using namespace pybind11::literals;
  _PyDiagnostics().attr("diagnose")(
      py_rule, py_level, py_message, "cpp_stack"_a = true);
}

} // namespace torch::onnx::diagnostics