#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/torch.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <algorithm>
namespace torch {
namespace jit {
namespace onnx {
using namespace ::c10::onnx;
}
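// Collects the constant tensor inputs of `node`: inputs produced by
// prim::Param are looked up in valsToParamsMap, onnx::Constant inputs are read
// from their value attribute, and all other inputs are skipped.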
std::vector<at::Tensor> getValues(
    Node* node,
    const ValueToParamPairMap& valsToParamsMap) {
  size_t numInputs = node->inputs().size();
  std::vector<at::Tensor> inputTensorValues;
  inputTensorValues.reserve(numInputs);
  for (auto val : node->inputs()) {
    if (val->node()->kind() == prim::Param) {
      auto itr = valsToParamsMap.find(val);
      if (itr == valsToParamsMap.end()) {
        continue;
      }
      inputTensorValues.push_back(itr->second.second.toTensor());
    } else if (val->node()->kind() == onnx::Constant) {
      inputTensorValues.push_back(val->node()->t(attr::value));
    } else {
      continue;
    }
  }
  return inputTensorValues;
}
// This pass fuses a Conv node followed by a BatchNormalization node into a
// single Conv node. The fusion is valid only if the BatchNormalization inputs
// scale, bias, mean, and var are all 1-D tensors of the same shape (C), and
// the size of dim 0 of the Conv weight matches the size of the
// BatchNormalization scale.
static void fuseConvBatchNorm(Block* b, ValueToParamPairMap& valsToParamsMap) {
  for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
    for (auto* child_block : it->blocks()) {
      fuseConvBatchNorm(child_block, valsToParamsMap);
    }
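    // Only fold when the Conv output feeds exactly one consumer and that
    // consumer is a BatchNormalization node.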
    if (it->kind() == onnx::Conv) {
      auto oldConv = *it;
      if (oldConv->outputs().at(0)->uses().size() != 1) {
        continue;
      }
      auto bnNode = oldConv->outputs().at(0)->uses()[0].user;
      if (bnNode->kind() != onnx::BatchNormalization) {
        continue;
      }
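      // In training mode, ONNX BatchNormalization exposes additional outputs
      // (running mean/var), so the output counts differ; only the
      // single-output eval form can be folded.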
      if (oldConv->outputs().size() !=
          bnNode->outputs().size()) { // BN layer is not in eval mode
        continue;
      }
      auto epsilon = bnNode->f(attr::epsilon);
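      // getValues returns only the constant inputs, i.e. the Conv weight and,
      // when present, the Conv bias. Fusion needs at least a constant weight,
      // and also a constant bias if the Conv has a bias input.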
      auto convInputVals = getValues(oldConv, valsToParamsMap);
      if (convInputVals.size() < 1 ||
          (oldConv->inputs().size() == 3 && convInputVals.size() != 2)) {
        continue;
      }
      auto bnInputVals = getValues(bnNode, valsToParamsMap);
      if (bnInputVals.size() != 4) {
        continue;
      }
      // See
      // https://github.com/onnx/onnx/blob/master/docs/Operators.md#BatchNormalization
      auto bnScale = bnInputVals[0].clone();
      auto bnB = bnInputVals[1].clone();
      auto bnMean = bnInputVals[2].clone();
      auto bnVar = bnInputVals[3].clone();
      // See https://github.com/onnx/onnx/blob/master/docs/Operators.md#Conv
      auto convW = convInputVals[0].clone();
      at::Tensor convB;
      if (!bnScale.is_floating_point() || !bnB.is_floating_point() ||
          !bnMean.is_floating_point() || !bnVar.is_floating_point() ||
          !convW.is_floating_point() || bnScale.dim() != 1 || bnB.dim() != 1 ||
          bnMean.dim() != 1 || bnVar.dim() != 1 ||
          !(bnScale.size(0) == bnB.size(0)) ||
          !(bnB.size(0) == bnMean.size(0)) ||
          !(bnMean.size(0) == bnVar.size(0)) || !(convW.dim() > 2) ||
          !(convW.size(0) == bnScale.size(0))) {
        continue;
      }
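      // Fold the BatchNorm affine transform into the Conv parameters:
      //   scale = bnScale / sqrt(bnVar + epsilon)
      //   W'    = convW * scale                    (per output channel)
      //   b'    = (convB - bnMean) * scale + bnB   (convB = 0 if absent)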
      bnVar = bnVar.add(epsilon);
      bnVar = bnVar.sqrt();
      bnScale = bnScale.div(bnVar);
      // Calculate weight
      for (const auto i : c10::irange(convW.size(0))) {
        convW[i] = convW[i].mul(bnScale[i]);
      }
      // Calculate bias
      if (oldConv->inputs().size() == 3) {
        convB = convInputVals[1].clone();
        convB = convB.sub(bnMean);
        convB = convB.mul(bnScale);
        convB = convB.add(bnB);
      } else {
        bnMean = bnMean.mul(bnScale);
        bnB = bnB.sub(bnMean);
        convB = bnB;
      }
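      // Build a replacement Conv that reuses the original activation input but
      // takes the folded weight/bias as fresh graph inputs registered in
      // valsToParamsMap, then redirect the BatchNorm output's uses to it and
      // remove both original nodes.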
      Node* newConv = b->owningGraph()->create(onnx::Conv, 1);
      newConv->outputs().at(0)->copyMetadata(bnNode->outputs().at(0));
      newConv->copyAttributes(*oldConv);
      newConv->insertBefore(bnNode);
      newConv->addInput(oldConv->inputs().at(0));
      newConv->copyMetadata(oldConv);

      auto newConvW = b->owningGraph()->addInput();
      valsToParamsMap.insert(
          {newConvW, std::make_pair(newConvW->debugName(), convW)});
      newConvW->inferTypeFrom(convW);
      newConv->addInput(newConvW);

      auto newConvB = b->owningGraph()->addInput();
      valsToParamsMap.insert(
          {newConvB, std::make_pair(newConvB->debugName(), convB)});
      newConvB->inferTypeFrom(convB);
      newConv->addInput(newConvB);

      bnNode->outputs().at(0)->replaceAllUsesWith(newConv->outputs().at(0));
      bnNode->destroy();
      it.destroyCurrent();
    }
  }
}
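
// Builds the value-to-params map for the block, runs the Conv + BatchNorm
// fusion over it (including nested blocks), and writes the updated parameters
// back into paramsDict.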
void EvalPeepholeONNX(Block* b, ParamMap& paramsDict) {
  auto valsToParamsMap = buildValueToParamsMap(b, paramsDict);
  fuseConvBatchNorm(b, valsToParamsMap);
  buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
}
void EvalPeepholeONNX(std::shared_ptr<Graph>& g, ParamMap& paramsDict) {
  EvalPeepholeONNX(g->block(), paramsDict);
  GRAPH_DUMP("After EvalPeepholeONNX:", g);
}
} // namespace jit
} // namespace torch