// (c) Facebook, Inc. and its affiliates. Confidential and proprietary.

#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>

#include <c10/util/Flags.h>
#include <c10/util/Logging.h>
#include <c10/util/string_view.h>

#include "caffe2/utils/knob_patcher.h"
#include "caffe2/utils/knobs.h"

namespace caffe2 {

namespace detail {
std::map<c10::string_view, bool*>& getRegisteredKnobs();
} // namespace detail

namespace {
class PatchNode {
 public:
  PatchNode(c10::string_view name, bool value);
  ~PatchNode();

  std::string name;
  bool oldValue{false};
  // Links in a doubly linked list of all active PatchNode objects for this
  // knob, in the order they were applied. This lets us restore the correct
  // value even if KnobPatcher objects are destroyed in arbitrary order.
  PatchNode* prev{nullptr};
  PatchNode* next{nullptr};
};
} // namespace

class KnobPatcher::PatchState : public PatchNode {
  using PatchNode::PatchNode;
};

KnobPatcher::KnobPatcher(c10::string_view name, bool value)
    : state_{std::make_unique<PatchState>(name, value)} {}
KnobPatcher::~KnobPatcher() = default;
KnobPatcher::KnobPatcher(KnobPatcher&&) noexcept = default;
KnobPatcher& KnobPatcher::operator=(KnobPatcher&&) noexcept = default;
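
// Usage sketch. "example_knob" is a hypothetical name; the knob must already
// be registered, or the constructor throws std::invalid_argument.
//
//   {
//     caffe2::KnobPatcher patch{"example_knob", true};
//     // Code here observes example_knob == true.
//   }
//   // Once `patch` is destroyed (and no later patch of the same knob is
//   // still alive), example_knob reverts to its previous value.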

namespace {
class Patcher {
 public:
  void patch(PatchNode* node, bool value) {
    std::lock_guard<std::mutex> lock{mutex_};
    node->oldValue = setKnobValue(node->name, value);
    auto ret = patches_.emplace(node->name, node);
    if (!ret.second) {
      // There was already another patcher for this knob.
      // Append the new node to the end of the linked list.
      node->prev = ret.first->second;
      CHECK(!node->prev->next);
      node->prev->next = node;
      ret.first->second = node;
    }
  }

  void unpatch(PatchNode* node) {
    std::lock_guard<std::mutex> lock{mutex_};
    // Remove this PatchNode from the linked list.
    if (node->prev) {
      node->prev->next = node->next;
    }
    if (node->next) {
      // There was another patch applied after this one. Hand our saved value
      // to it so that it restores the right value when it is removed.
      node->next->prev = node->prev;
      node->next->oldValue = node->oldValue;
    } else {
      // This was the most recently applied patch for this knob,
      // so restore the knob value.
      setKnobValue(node->name, node->oldValue);
      // The patches_ map should point to this node.
      // Update it to point to the previous patch, if there is one.
      auto iter = patches_.find(node->name);
      if (iter == patches_.end()) {
        LOG(FATAL) << "patch node not found when unpatching knob value";
      }
      TORCH_CHECK_EQ(iter->second, node);
      if (node->prev) {
        iter->second = node->prev;
      } else {
        patches_.erase(iter);
      }
    }
  }
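
  // Worked example of how unpatch() passes oldValue forward to a later patch,
  // assuming a knob that starts out false:
  //   A = patch(true):  A->oldValue = false, knob = true,  list: A
  //   B = patch(true):  B->oldValue = true,  knob = true,  list: A <-> B
  //   unpatch(A):       B->oldValue = false (taken from A), knob stays true
  //   unpatch(B):       B is now the only patch, so knob = B->oldValue = false
  // Without that hand-off, destroying A and then B would leave the knob set
  // to true instead of its original value.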

 private:
  // Sets the registered knob `name` to `value` and returns its previous value.
  // Throws std::invalid_argument if no knob with that name is registered.
  bool setKnobValue(c10::string_view name, bool value) {
    auto& knobs = caffe2::detail::getRegisteredKnobs();
    auto iter = knobs.find(name);
    if (iter == knobs.end()) {
      throw std::invalid_argument(
          "attempted to patch unknown knob \"" + std::string(name) + "\"");
    }
    bool oldValue = *(iter->second);
    *iter->second = value;
    return oldValue;
  }

  std::mutex mutex_;
  std::map<std::string, PatchNode*> patches_;
};

Patcher& getPatcher() {
  static Patcher patcher;
  return patcher;
}

PatchNode::PatchNode(c10::string_view knobName, bool value)
    : name{knobName} {
  getPatcher().patch(this, value);
}

PatchNode::~PatchNode() {
  try {
    getPatcher().unpatch(this);
  } catch (const std::exception& ex) {
    // This should only happen because of a programming bug: unpatch() calls
    // setKnobValue(), which throws if the knob name is not found. Catching
    // here keeps clang-tidy happy, since destructors must not throw.
    LOG(FATAL) << "error removing knob patch: " << ex.what();
  }
}
} // namespace
} // namespace caffe2