1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116
|
#include "caffe2/opt/distributed.h"

#include <utility>

#include "caffe2/opt/converter.h"
namespace caffe2 {
using namespace nom::repr;
/// Stores device option `d` on the Caffe2 annotation of the operator held by
/// graph node `n`.
///
/// @param n  Node whose NeuralNetOperator annotation is updated.
/// @param d  Device option handed to the annotation's setDeviceOption().
///
/// Fails via CAFFE_ENFORCE when the operator's annotation cannot be cast to
/// caffe2::Caffe2Annotation.
void setDeviceOption(NNGraph::NodeRef n, caffe2::DeviceOption& d) {
  // Called only for its side effect; the return value is not used here.
  // Judging by its name and the enforce message below, it is expected to make
  // sure the node's operator carries a Caffe2Annotation.
  getOrAddCaffe2Annotation(n);

  auto c2Annot = dyn_cast<caffe2::Caffe2Annotation>(
      nn::get<NeuralNetOperator>(n)->getMutableAnnotation());
  CAFFE_ENFORCE(c2Annot, "getOrAddCaffe2Annotation failed!");

  c2Annot->setDeviceOption(d);
}
/// Applies the device options in `blobMap` to the Declare/Export nodes of `nn`.
///
/// For every Declare node, the name of the tensor on its first output is
/// looked up in `blobMap`; on a hit the mapped device option is set on the
/// Declare node. Export nodes are handled the same way using the tensor on
/// their first input.
///
/// @param blobMap  Blob name -> device option to apply.
/// @param nn       Module whose Declare/Export nodes are annotated.
///
/// Fails via CAFFE_ENFORCE when
///  - the same mapped name is hit by two Declare nodes, or by two Export
///    nodes (the mapping would be ambiguous), or
///  - some key of `blobMap` matched neither a Declare nor an Export node.
void addBlobDeviceOptions(
    std::map<std::string, caffe2::DeviceOption> blobMap,
    NNModule* nn) {
  // Duplicates are tracked separately per side: one name may appear once
  // among the inputs and once among the outputs. `seen` is the union of both
  // and is compared against blobMap at the end.
  std::unordered_set<std::string> seen_inputs;
  std::unordered_set<std::string> seen_outputs;
  std::unordered_set<std::string> seen;

  // Input side: Declare -> tensor.
  for (auto& declareNode : nn::filter<Declare>(*nn)) {
    auto input = nn::get<nom::repr::Tensor>(nn::getOutputs(declareNode).at(0));
    const auto& blobName = input->getName();
    const auto match = blobMap.find(blobName);
    if (match == blobMap.end()) {
      continue; // no device option requested for this blob
    }
    CAFFE_ENFORCE(
        !seen_inputs.count(input->getName()),
        "Ambiguous name->deviceOption map. Please do this manually. Affected blob: " + input->getName());
    seen_inputs.insert(blobName);
    seen.insert(blobName);
    setDeviceOption(declareNode, match->second);
  }

  // Output side: tensor -> Export.
  for (auto& exportNode : nn::filter<Export>(*nn)) {
    auto output = nn::get<nom::repr::Tensor>(nn::getInputs(exportNode).at(0));
    const auto& blobName = output->getName();
    const auto match = blobMap.find(blobName);
    if (match == blobMap.end()) {
      continue; // no device option requested for this blob
    }
    CAFFE_ENFORCE(
        !seen_outputs.count(output->getName()),
        "Ambiguous name->deviceOption map. Please do this manually. Affected blob: " + output->getName());
    seen_outputs.insert(blobName);
    seen.insert(blobName);
    setDeviceOption(exportNode, match->second);
  }

  // `seen` only ever receives keys of blobMap, so equal sizes means every
  // requested name was applied somewhere.
  if (seen.size() == blobMap.size()) {
    return;
  }

  // Collect the names that were never matched, each quoted and followed by a
  // space, for the error message.
  std::string unusedNames;
  for (const auto& entry : blobMap) {
    if (seen.count(entry.first) == 0) {
      unusedNames += "\"" + entry.first + "\" ";
    }
  }
  CAFFE_ENFORCE(
      seen.size() == blobMap.size(),
      "Unused names in the blob map: ",
      unusedNames);
}
/// Replaces the module's input/output sets with explicit marker nodes in the
/// data-flow graph.
///
/// Each node in `nn->inputs` gets a new Declare node with an edge
/// Declare -> input; each node in `nn->outputs` gets a new Export node with an
/// edge output -> Export. Both sets are emptied afterwards.
///
/// @param nn  Module to rewrite in place.
void injectDataEdgeIndicators(nom::repr::NNModule* nn) {
  auto& graph = nn->dataFlow;

  for (auto& source : nn->inputs) {
    graph.createEdge(graph.createNode(std::make_unique<Declare>()), source);
  }

  for (auto& sink : nn->outputs) {
    graph.createEdge(sink, graph.createNode(std::make_unique<Export>()));
  }

  // The marker nodes now carry this information.
  nn->inputs.clear();
  nn->outputs.clear();
}
/// Strips Declare/Export marker nodes from the data-flow graph and records
/// the tensors they were attached to in the module's input/output sets.
///
/// For each Declare node its first output is inserted into `nn->inputs`; for
/// each Export node its first input is inserted into `nn->outputs`. The
/// marker node is deleted from the graph in both cases.
///
/// @param nn  Module to rewrite in place.
void removeDataEdgeIndicators(nom::repr::NNModule* nn) {
  auto& graph = nn->dataFlow;

  for (auto& marker : nn::filter<Declare>(*nn)) {
    // Record the neighbour first; the marker is gone after deleteNode().
    nn->inputs.insert(nn::getOutputs(marker).at(0));
    graph.deleteNode(marker);
  }

  for (auto& marker : nn::filter<Export>(*nn)) {
    nn->outputs.insert(nn::getInputs(marker).at(0));
    graph.deleteNode(marker);
  }
}
/// Converts `net` to an NNModule and annotates its external blobs with device
/// options.
///
/// The net is converted with the single-argument convertToNNModule overload,
/// Declare/Export marker nodes are injected for the module's inputs and
/// outputs, and the options in `blobMap` are then applied to those markers
/// via addBlobDeviceOptions().
///
/// @param net      Net definition to convert.
/// @param blobMap  Blob name -> device option to apply.
/// @return The converted module, with marker nodes still in the graph.
///
/// Any CAFFE_ENFORCE failure raised by addBlobDeviceOptions() propagates to
/// the caller.
nom::repr::NNModule convertToNNModule(
    caffe2::NetDef& net,
    std::map<std::string, caffe2::DeviceOption> blobMap) {
  auto nn = convertToNNModule(net);
  injectDataEdgeIndicators(&nn);
  // blobMap is this function's own by-value copy and is not read again, so it
  // is moved into the by-value parameter instead of being copied a second time.
  addBlobDeviceOptions(std::move(blobMap), &nn);
  return nn;
}
} // namespace caffe2
|