#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <c10/util/irange.h>

namespace torch {
namespace jit {

namespace onnx {
using namespace ::c10::onnx;
}
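
// De-duplicates graph initializers (parameters) using the given comparator
// `comp`. When two initializer inputs compare equal, the later one is dropped
// and replaced by an onnx::Identity node fed by the kept initializer, so all
// of its uses remain valid. valsToParamsMap is updated in place.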
void DeduplicateInitializers(
    std::shared_ptr<Graph>& g,
    ValueToParamPairMap& valsToParamsMap,
    bool (*comp)(at::Tensor&, at::Tensor&)) {
  auto is_same_tensor_as = [&valsToParamsMap, comp](Value* v1) {
    return [&valsToParamsMap, v1, comp](Value* v2) {
      if ((valsToParamsMap.find(v1) == valsToParamsMap.end()) ||
          (valsToParamsMap.find(v2) == valsToParamsMap.end())) {
        return false;
      }
      auto iv1 = valsToParamsMap.find(v1)->second.second;
      auto iv2 = valsToParamsMap.find(v2)->second.second;
      if (!iv1.isTensor() || !iv2.isTensor()) {
        return false;
      }
      auto t1 = iv1.toTensor();
      auto t2 = iv2.toTensor();
      return comp(t1, t2);
    };
  };
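  // Scan graph inputs in order: the first input holding a given tensor is
  // kept in uniqueVals; any later input whose tensor compares equal under
  // `comp` is recorded for removal and rerouted through an onnx::Identity
  // node that forwards the kept value.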
  std::vector<Value*> uniqueVals;
  std::vector<size_t> inputsIndicesToRemove;
  auto b = g->block();
  for (auto i : c10::irange(b->inputs().size())) {
    auto v = g->inputs().at(i);
    if (valsToParamsMap.find(v) == valsToParamsMap.end()) {
      // Skip model inputs
      continue;
    }
    auto it = std::find_if(
        uniqueVals.begin(), uniqueVals.end(), is_same_tensor_as(v));
    if (it == uniqueVals.end()) {
      uniqueVals.emplace_back(v);
    } else {
      inputsIndicesToRemove.emplace_back(i);
      auto id_node = g->create(onnx::Identity);
      id_node->insertAfter(g->block()->param_node());
      id_node->addInput(*it);
      id_node->output()->copyMetadata(v);
      id_node->copyMetadata(g->block()->param_node());
      v->replaceAllUsesWith(id_node->output());
    }
  }
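  // Erase duplicated inputs in reverse (highest index first) so the indices
  // recorded in inputsIndicesToRemove stay valid as inputs are removed.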
  for (auto it = inputsIndicesToRemove.rbegin();
       it != inputsIndicesToRemove.rend();
       ++it) {
    valsToParamsMap.erase(g->inputs().at(*it));
    g->eraseInput(*it);
  }
}
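
// Comparator for the first de-duplication pass: two initializers are
// duplicates when they have identical sizes and strides and their storages
// share the same data pointer (i.e. they alias the same memory).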
bool DeduplicateInitializersByDataPtr(at::Tensor& t1, at::Tensor& t2) {
  return t1.sizes().equals(t2.sizes()) && t1.strides().equals(t2.strides()) &&
      (t1.has_storage() && t2.has_storage() && t1.data_ptr() == t2.data_ptr());
}
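
// Comparator for the value-based de-duplication pass: initializers must match
// in dtype, sizes, and strides, and be element-wise equal. Tensors on
// different devices are copied to CPU before comparing.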
bool DeduplicateInitializersByValue(at::Tensor& t1, at::Tensor& t2) {
  if (t1.dtype() != t2.dtype() || !t1.sizes().equals(t2.sizes()) ||
      !t1.strides().equals(t2.strides())) {
    return false;
  }
  if (t1.device() != t2.device()) {
    return t1.to("cpu").equal(t2.to("cpu"));
  }
  return t1.equal(t2);
}
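
// Entry point of the pass: builds the value-to-parameter map, de-duplicates
// initializers that alias the same memory (always), optionally de-duplicates
// by value for inference graphs, and writes the result back into paramsDict.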
void DeduplicateInitializers(
    std::shared_ptr<Graph>& g,
    std::map<std::string, IValue>& paramsDict,
    bool is_train) {
  auto valsToParamsMap = buildValueToParamsMap(g->block(), paramsDict);
  // The ONNX spec does not support parameters that share memory.
  // This pass de-duplicates those parameters. Training is not affected.
  DeduplicateInitializers(g, valsToParamsMap, DeduplicateInitializersByDataPtr);
  if (!is_train) {
    // More aggressive de-duplication based on tensor values, producing a more
    // compact model for inference. For training this pass is disabled,
    // because duplicated parameters may be updated independently.
    DeduplicateInitializers(g, valsToParamsMap, DeduplicateInitializersByValue);
  }
  buildParamsMapFromValueToParamsMap(valsToParamsMap, paramsDict);
}
} // namespace jit
} // namespace torch