File: diagnostics.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (71 lines) | stat: -rw-r--r-- 1,713 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#pragma once
#include <torch/csrc/onnx/diagnostics/generated/rules.h>
#include <torch/csrc/utils/pybind.h>
#include <string>

namespace torch {
namespace onnx {
namespace diagnostics {

/**
 * @brief Severity level attached to a diagnostic.
 * @details The set of levels comes from the SARIF specification and must not
 * be extended here. To categorize diagnostics in other ways, use Tag instead.
 * The numeric value of each enumerator is used as an index into
 * kPyLevelNames, so the two must be kept in the same order.
 * @todo Introduce Tag to C++ api.
 */
enum class Level : uint8_t {
  kNone = 0,
  kNote = 1,
  kWarning = 2,
  kError = 3,
};

// Attribute names looked up on the Python-side `levels` object, indexed by the
// numeric value of Level (see _PyLevel). The order must match the enum.
static constexpr const char* kPyLevelNames[] = {
    "NONE", // Level::kNone
    "NOTE", // Level::kNote
    "WARNING", // Level::kWarning
    "ERROR", // Level::kError
};

// Wrappers around Python diagnostics.
// TODO: Move to a .cpp file in a follow-up PR.

// Imports the Python module `torch.onnx._internal.diagnostics` and returns it
// as a generic py::object. The import is performed on every call.
// NOTE(review): presumably the caller must hold the GIL - confirm.
// NOTE(review): identifiers starting with an underscore followed by an
// uppercase letter are reserved in C++; renaming would require updating all
// callers, so the name is left as is.
inline py::object _PyDiagnostics() {
  py::object diagnostics_module =
      py::module::import("torch.onnx._internal.diagnostics");
  return diagnostics_module;
}

// Returns the `engine` attribute of the Python diagnostics module.
inline py::object _PyEngine() {
  py::object diagnostics_module = _PyDiagnostics();
  return diagnostics_module.attr("engine");
}

// Returns the `context` attribute of the Python diagnostics module.
inline py::object _PyContext() {
  py::object diagnostics_module = _PyDiagnostics();
  return diagnostics_module.attr("context");
}

// Maps a C++ Rule to its Python counterpart: the attribute of the Python
// module's `rules` object whose name is kPyRuleNames[rule].
// No range check is performed on the index derived from `rule`.
inline py::object _PyRule(Rule rule) {
  const auto rule_index = static_cast<uint32_t>(rule);
  py::object py_rules = _PyDiagnostics().attr("rules");
  return py_rules.attr(kPyRuleNames[rule_index]);
}

// Maps a C++ Level to its Python counterpart: the attribute of the Python
// module's `levels` object whose name is kPyLevelNames[level].
// No range check is performed on the index derived from `level`.
inline py::object _PyLevel(Level level) {
  const auto level_index = static_cast<uint32_t>(level);
  py::object py_levels = _PyDiagnostics().attr("levels");
  return py_levels.attr(kPyLevelNames[level_index]);
}

// Reports a diagnostic through the Python diagnostics context.
//
// Calls `context.diagnose(rule, level, message_args=messageArgs)` on the
// Python side, where `rule` and `level` are the Python objects corresponding
// to the given C++ enum values.
//
// @param rule        Rule to report; resolved via _PyRule.
// @param level       Level of the diagnostic; resolved via _PyLevel.
// @param messageArgs Passed as the `message_args` keyword argument.
//                    Defaults to an empty list.
inline void Diagnose(
    Rule rule,
    Level level,
    std::vector<std::string> messageArgs = {}) {
  // Resolve the Python-side objects first, in the order rule, level, context.
  py::object rule_obj = _PyRule(rule);
  py::object level_obj = _PyLevel(level);
  py::object context_obj = _PyContext();

  // TODO: statically check that size of messageArgs matches with rule.
  py::dict call_kwargs;
  call_kwargs["message_args"] = messageArgs;

  context_obj.attr("diagnose")(rule_obj, level_obj, **call_kwargs);
}

} // namespace diagnostics
} // namespace onnx
} // namespace torch