File: anomaly_mode.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (79 lines) | stat: -rw-r--r-- 2,306 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#include <c10/util/Backtrace.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <mutex>

namespace torch {
namespace autograd {

// Definitions of AnomalyMode's static flags. Anomaly mode starts out disabled,
// while the NaN-check flag starts out set.
// NOTE(review): presumably both are updated through AnomalyMode::set_enabled,
// which is declared outside this file — confirm in anomaly_mode.h.
bool AnomalyMode::_enabled = false;
bool AnomalyMode::_check_nan = true;

namespace {
// Returns the single mutex shared by every caller of this function. It is a
// function-local static, so it is constructed on first use and the same
// instance is handed back on every call.
std::mutex& get_anomaly_guard_lock() {
  static std::mutex guard_mutex;
  return guard_mutex;
}

// Returns a mutable reference to a counter that starts at zero. The counter is
// a function-local static, so every call refers to the same storage. This
// function does no locking of its own.
uint32_t& get_anomaly_counter() {
  static uint32_t live_guards{0};
  return live_guards;
}
} // namespace

// Turns anomaly mode on for the lifetime of this guard.
//
// Emits a once-only warning that the mode is meant for debugging, then, while
// holding the guard mutex, bumps the count of live guards, records the
// NaN-check setting currently reported by AnomalyMode, and enables anomaly
// mode with the requested `check_nan` value.
DetectAnomalyGuard::DetectAnomalyGuard(bool check_nan) {
  TORCH_WARN_ONCE(
      "This mode should be enabled only for debugging as the different tests will slow down your program execution.");
  const std::lock_guard<std::mutex> hold(get_anomaly_guard_lock());
  ++get_anomaly_counter();
  // Saved so the destructor can hand the previous setting back.
  prev_check_nan_ = AnomalyMode::should_check_nan();
  AnomalyMode::set_enabled(true, check_nan);
}

// Releases this guard: under the guard mutex, decrements the count of live
// guards and keeps anomaly mode enabled only if other guards remain. The
// NaN-check setting captured at construction is passed back to AnomalyMode.
DetectAnomalyGuard::~DetectAnomalyGuard() {
  const std::lock_guard<std::mutex> hold(get_anomaly_guard_lock());
  const uint32_t remaining = --get_anomaly_counter();
  const bool still_enabled = remaining > 0;
  AnomalyMode::set_enabled(still_enabled, prev_check_nan_);
}

// Out-of-line defaulted destructor: the definition lives in this translation
// unit and performs no custom cleanup beyond destroying the members.
AnomalyMetadata::~AnomalyMetadata() = default;

// Captures a backtrace via c10::get_backtrace and stores it in traceback_,
// which print_stack later includes in its warnings.
void AnomalyMetadata::store_stack() {
  // One frame is skipped — presumably this function's own frame; confirm
  // against c10::get_backtrace's contract.
  constexpr size_t kFramesToSkip = 1;
  traceback_ = c10::get_backtrace(kFramesToSkip);
}

// Emits a warning containing the stored forward traceback for the node in
// which an error was detected, then walks the chain of parent nodes and emits
// one further warning per parent with that parent's name and stored traceback.
//
// current_node_name: name of the node where the error was detected; it is only
//                    used in the text of the first warning.
//
// This object's own parent_ is left untouched, so the function can be called
// repeatedly with the same result.
void AnomalyMetadata::print_stack(const std::string& current_node_name) {
  TORCH_WARN(
      "Error detected in ",
      current_node_name,
      ". ",
      "Traceback of forward call that caused the error:\n",
      traceback_);

  // Iterate with a local copy of the handle. Binding a reference to parent_
  // here would make the reassignment at the bottom of the loop overwrite this
  // object's own parent_, leaving it null once the walk finishes.
  auto cur_parent = parent_;
  // if there is no "parent_" in metadata, then it means this metadata's node
  // is the root and stop printing the traceback
  while (cur_parent) {
    auto parent_metadata = cur_parent->metadata();
    TORCH_WARN(
        "\n\n",
        "Previous calculation was induced by ",
        cur_parent->name(),
        ". "
        "Traceback of forward call that induced the previous calculation:\n",
        parent_metadata->traceback_);
    // move up to the parent of this node; if this node is a root, its parent_
    // is simply null and the loop ends
    cur_parent = parent_metadata->parent_;
  }
}

void AnomalyMetadata::assign_parent(const std::shared_ptr<Node>& parent_node) {
  parent_ = parent_node;
}

} // namespace autograd
} // namespace torch