1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88
|
#include <torch/csrc/jit/operator_upgraders/upgraders.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace jit {
static UpgradersMap upgradersMap;
void UpgradersMap::set_content(
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
// make sure we populate the map only once
std::lock_guard<std::mutex> _(lock);
if (isPopulated) {
return;
}
content_ = std::move(content);
isPopulated = true;
}
int UpgradersMap::count() {
std::lock_guard<std::mutex> _(lock);
return content_.size();
}
bool UpgradersMap::is_populated() {
std::lock_guard<std::mutex> _(lock);
return isPopulated;
}
const std::unordered_map<std::string, std::shared_ptr<Graph>>& UpgradersMap::
get_content() {
std::lock_guard<std::mutex> _(lock);
return content_;
}
void UpgradersMap::test_only_set_content(
const std::unordered_map<std::string, std::string>& content) {
std::lock_guard<std::mutex> _(lock);
for (const auto& entry : content) {
auto graph = std::make_shared<Graph>();
torch::jit::parseIR(entry.second, graph.get());
content_.insert(std::make_pair(entry.first, graph));
}
}
void UpgradersMap::test_only_remove_content(
const std::unordered_map<std::string, std::string>& content) {
std::lock_guard<std::mutex> _(lock);
for (const auto& entry : content) {
content_.erase(entry.first);
}
}
void populate_upgraders_map(
std::unordered_map<std::string, std::shared_ptr<Graph>>&& content) {
upgradersMap.set_content(std::move(content));
}
int get_upgraders_map_size() {
return upgradersMap.count();
}
bool is_upgraders_map_populated() {
return upgradersMap.is_populated();
}
const std::unordered_map<std::string, std::shared_ptr<Graph>>&
dump_upgraders_map() {
return upgradersMap.get_content();
}
void test_only_populate_upgraders(
const std::unordered_map<std::string, std::string>& content) {
upgradersMap.test_only_set_content(content);
}
void test_only_remove_upgraders(
const std::unordered_map<std::string, std::string>& content) {
upgradersMap.test_only_remove_content(content);
}
} // namespace jit
} // namespace torch
|