File: knob_patcher.cc

package info (click to toggle)
pytorch 1.13.1+dfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (134 lines) | stat: -rw-r--r-- 3,775 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.

#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>

#include <c10/util/string_view.h>
#include <c10/util/Flags.h>
#include <c10/util/Logging.h>

#include "caffe2/utils/knobs.h"
#include "caffe2/utils/knob_patcher.h"

namespace caffe2 {
namespace detail {
// Registry of all knobs, keyed by knob name, mapping to the storage of each
// knob's current value. Only declared here; the definition lives in another
// translation unit (presumably alongside the knob registration code in
// knobs.cc — NOTE(review): confirm). The Patcher below looks knobs up by name
// in this map and reads/writes their values through the stored pointers.
std::map<c10::string_view, bool*>& getRegisteredKnobs();
} // namespace detail

namespace {
// One active patch of a single knob. Constructing a PatchNode applies the
// patch through the global Patcher; destroying it removes the patch again.
//
// The Patcher tracks PatchNodes by address (they are linked into a per-knob
// list through `prev`/`next`), so a PatchNode must never be copied or moved:
// a copy would be "unpatched" by its destructor without ever having been
// patched, corrupting the list. Copy and move operations are deleted.
class PatchNode {
 public:
  PatchNode(c10::string_view name, bool value);
  ~PatchNode();

  PatchNode(const PatchNode&) = delete;
  PatchNode& operator=(const PatchNode&) = delete;
  PatchNode(PatchNode&&) = delete;
  PatchNode& operator=(PatchNode&&) = delete;

  // Name of the knob this node patches.
  std::string name;
  // Knob value in effect just before this patch was applied. It is restored
  // (or handed on to the next-newer patch) when this node is unpatched.
  bool oldValue{false};
  // Nodes to form a linked list of existing PatchState objects for this knob:
  // `prev` is the next-older patch, `next` the next-newer one.
  // This allows us to restore state correctly even if KnobPatcher objects
  // are destroyed in any arbitrary order.
  PatchNode* prev{nullptr};
  PatchNode* next{nullptr};
};
} // namespace

class KnobPatcher::PatchState : public PatchNode {
  using PatchNode::PatchNode;
};

// Sets knob `name` to `value` right away. The patch remains in effect until
// the PatchState owned through `state_` is destroyed. Construction fails with
// std::invalid_argument (raised while registering the patch) when `name` is
// not a registered knob.
KnobPatcher::KnobPatcher(c10::string_view name, bool value)
    : state_(std::make_unique<PatchState>(name, value)) {}

// Moves transfer ownership of the underlying PatchState; destruction releases
// it, which removes the patch. These are defaulted here, out of line, where
// PatchState is a complete type (presumably the header only forward-declares
// it — NOTE(review): confirm against knob_patcher.h).
KnobPatcher::KnobPatcher(KnobPatcher&&) noexcept = default;
KnobPatcher& KnobPatcher::operator=(KnobPatcher&&) noexcept = default;
KnobPatcher::~KnobPatcher() = default;

namespace {

// Registry of the patches currently applied to knobs.
//
// For each knob name, `patches_` points at the most recently applied
// PatchNode; older patches of the same knob are reachable through the nodes'
// `prev` links. All access is serialized by `mutex_`.
class Patcher {
 public:
  // Sets the knob named `node->name` to `value` and records `node` as the
  // newest patch of that knob. The value the knob had before is saved in
  // `node->oldValue` so unpatch() can restore it.
  //
  // Throws std::invalid_argument if no knob with that name is registered.
  // Every operation that can throw (knob lookup, map insertion) runs before
  // the knob value is written. A failure therefore leaves both the knob and
  // this registry unchanged; this matters because a PatchNode whose
  // constructor throws is never destroyed, so nothing would undo a
  // half-applied patch.
  void patch(PatchNode* node, bool value) {
    std::lock_guard<std::mutex> lock{mutex_};

    bool* knob = findKnob(node->name);

    // std::map::emplace may allocate and throw; do it before touching the knob.
    auto ret = patches_.emplace(node->name, node);
    if (!ret.second) {
      // There was already another patcher for this knob.
      // Append the new node to the tail of the linked list.
      node->prev = ret.first->second;
      CHECK(!node->prev->next);
      node->prev->next = node;
      ret.first->second = node;
    }

    // Commit point: nothing below can throw.
    node->oldValue = *knob;
    *knob = value;
  }

  // Removes `node` from its knob's list of patches.
  //
  // If `node` is the newest patch of its knob, the knob is restored to
  // `node->oldValue`. The registry entry then moves to the next-older patch,
  // or is erased if there is none. Otherwise the knob keeps the newer patch's
  // value, and `node->oldValue` is handed to that newer patch so it restores
  // the correct value when it is removed.
  void unpatch(PatchNode* node) {
    std::lock_guard<std::mutex> lock{mutex_};

    // Remove this PatchNode from the linked list
    if (node->prev) {
      node->prev->next = node->next;
    }
    if (node->next) {
      // There was another patch applied after this one.
      node->next->prev = node->prev;
      node->next->oldValue = node->oldValue;
    } else {
      // This was the most recently applied patch for this knob,
      // so restore the knob value.
      *findKnob(node->name) = node->oldValue;

      // The patches_ map should point to this node.
      // Update it to point to the previous patch, if there is one.
      auto iter = patches_.find(node->name);
      if (iter == patches_.end()) {
        LOG(FATAL) << "patch node not found when unpatching knob value";
      }
      TORCH_CHECK_EQ(iter->second, node);
      if (node->prev) {
        iter->second = node->prev;
      } else {
        patches_.erase(iter);
      }
    }
  }

 private:
  // Returns the storage of the registered knob called `name`.
  // Throws std::invalid_argument if no such knob is registered.
  static bool* findKnob(c10::string_view name) {
    auto& knobs = caffe2::detail::getRegisteredKnobs();
    auto iter = knobs.find(name);
    if (iter == knobs.end()) {
      throw std::invalid_argument(
          "attempted to patch unknown knob \"" + std::string(name) + "\"");
    }
    return iter->second;
  }

  std::mutex mutex_;
  // Knob name -> newest active patch of that knob.
  std::map<std::string, PatchNode*> patches_;
};

// Returns the one Patcher shared by every knob patch in the process. It is
// created on first use; initialization of a function-local static is
// thread-safe.
Patcher& getPatcher() {
  static Patcher theInstance;
  return theInstance;
}

// Stores the knob name, then registers this node with the global Patcher.
// The Patcher saves the knob's current value into `oldValue` and applies
// `value`. If the knob is unknown the Patcher throws, construction fails, and
// the destructor (which would unpatch) never runs.
PatchNode::PatchNode(c10::string_view knobName, bool value)
    : name(knobName.data(), knobName.size()) {
  Patcher& registry = getPatcher();
  registry.patch(this, value);
}

// Withdraws this node's patch through the global Patcher. Depending on the
// node's position in the list, this either restores the knob or passes the
// saved value on to a newer patch. A destructor must not let an exception
// escape. unpatch() can in theory throw, when its knob lookup fails, which
// would indicate a programming error. Any such exception is therefore
// reported as a fatal log instead.
PatchNode::~PatchNode() {
  try {
    Patcher& registry = getPatcher();
    registry.unpatch(this);
  } catch (const std::exception& err) {
    LOG(FATAL) << "error removing knob patch: " << err.what();
  }
}

} // namespace
} // namespace caffe2