#include "caffe2/opt/optimize_ideep.h"
#include "caffe2/opt/converter.h"
#include <numeric>
#ifdef USE_MKLDNN
#include <cpuinfo.h>
#include "caffe2/ideep/ideep_utils.h"
#endif
namespace caffe2 {
namespace opt {
using namespace nom;
#ifndef USE_MKLDNN
void OptimizeForMkldnn(
repr::NNModule* nn,
caffe2::Workspace* ws,
bool training_mode) {
LOG(WARNING) << "Only optimizations for IDEEP are supported";
}
#else
USE_IDEEP_DEF_ALIASES();
Blob* getBlob(const std::string& name, caffe2::Workspace* ws) {
CAFFE_ENFORCE(ws->HasBlob(name), "Blob ", name, " not in workspace");
return ws->GetBlob(name);
}
Blob* getBlob(repr::NNGraph::NodeRef node, caffe2::Workspace* ws) {
auto tensor = repr::nn::get<repr::Tensor>(node);
return getBlob(tensor->getName(), ws);
}
template <class T>
T getTensor(Blob* blob) {
CAFFE_ENFORCE(blob, "Blob is invalid");
return blob->template Get<T>();
}
template <class T>
T* getMutableTensor(Blob* blob) {
CAFFE_ENFORCE(blob, "Blob is invalid");
if (blob && blob->template IsType<T>()) {
return blob->template GetMutable<T>();
}
return nullptr;
}
const caffe2::OperatorDef& getOpDef(const repr::NeuralNetOperator& nnOp) {
auto annotation = nnOp.getAnnotation();
if (annotation == nullptr) {
CAFFE_THROW("Cannot get Operator annotation");
}
return dyn_cast<Caffe2Annotation>(annotation)->getOperatorDef();
}
caffe2::OperatorDef* getMutableOpDef(repr::NeuralNetOperator& nnOp) {
auto annotation = nnOp.getMutableAnnotation();
if (annotation == nullptr) {
CAFFE_THROW("Cannot get Operator annotation");
}
return dyn_cast<Caffe2Annotation>(annotation)->getMutableOperatorDef();
}
bool isOpType(const repr::NNGraph::NodeRef& nodeRef, const string& typeName) {
if (!repr::nn::is<repr::NeuralNetOperator>(nodeRef)) {
return false;
}
auto op = repr::nn::get<repr::NeuralNetOperator>(nodeRef);
// Bind by reference: copying the whole OperatorDef just to read its type is
// unnecessary.
const auto& opDef = getOpDef(*op);
return opDef.type() == typeName;
}
bool isOnIdeepDevice(const repr::NeuralNetOperator& nnOp) {
// We only want to fuse for IDEEP operators
const auto& op = getOpDef(nnOp);
return op.device_option().device_type() == DeviceTypeProto::PROTO_IDEEP;
}
bool isConvFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
// Here we only check the type of ConvFusion op (for FP32 only)
if (!repr::nn::is<repr::Conv>(convNode)) {
return false;
}
auto conv = repr::nn::get<repr::Conv>(convNode);
auto& op = getOpDef(*conv);
if (op.type() == "ConvFusion") {
for (const auto& arg : op.arg()) {
if (arg.name() == "fusion_type") {
if (fusion_type == FUSION_MAX) {
return true;
}
return arg.i() == fusion_type;
}
}
}
return false;
}
void resetConvForFusion(repr::NNGraph::NodeRef convNode, int fusion_type) {
auto conv = repr::nn::get<repr::Conv>(convNode);
auto* op = getMutableOpDef(*conv);
if (op == nullptr) {
return;
}
if (op->type() == "ConvFusion") {
CAFFE_ENFORCE(fusion_type == FUSION_CONV_RELU, "Invalid nested fusion");
for (auto& arg : *op->mutable_arg()) {
if (arg.name() == "fusion_type") {
CAFFE_ENFORCE(arg.i() == FUSION_CONV_SUM, "Invalid nested fusion");
// Only from FUSION_CONV_SUM to FUSION_CONV_SUM_RELU
arg.set_i(FUSION_CONV_SUM_RELU);
return;
}
}
CAFFE_THROW("Cannot find fusion_type argument in ConvFusion");
}
CAFFE_ENFORCE_LT(fusion_type, FUSION_CONV_SUM_RELU, "Invalid fusion type");
op->set_type("ConvFusion");
auto* arg = op->add_arg();
arg->set_name("fusion_type");
arg->set_i(fusion_type);
}
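// Illustrative sketch only (kept in a comment, not compiled): the fusion-type
// transitions that resetConvForFusion appears to allow, written as a standalone
// function. The enum `Fusion`, its numeric values, and `nextFusion` are
// hypothetical stand-ins for the FUSION_* constants from ideep_utils.h; only
// the accepted transitions are meant to match the code above. Unlike the
// CAFFE_ENFORCE_LT check above, this sketch also rejects requesting "no fusion"
// on a plain Conv. Requires C++17 for std::optional.
/*
```cpp
#include <cassert>
#include <optional>

// Hypothetical stand-ins for the FUSION_* constants; the values are assumed.
enum Fusion { kNone = 0, kConvRelu = 1, kConvSum = 2, kConvSumRelu = 3 };

// Returns the conv's fusion type after applying `requested` to a conv that is
// currently in state `current`, or std::nullopt if the transition is rejected.
std::optional<Fusion> nextFusion(Fusion current, Fusion requested) {
  if (current == kNone) {
    // A plain Conv can become Conv+ReLU or Conv+Sum directly.
    if (requested == kConvRelu || requested == kConvSum) {
      return requested;
    }
    return std::nullopt;
  }
  // The only nested fusion: an existing Conv+Sum absorbs a following ReLU.
  if (current == kConvSum && requested == kConvRelu) {
    return kConvSumRelu;
  }
  return std::nullopt;
}
```
*/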
void removeArg(repr::NeuralNetOperator& nnOp, std::string argName) {
auto* op = getMutableOpDef(nnOp);
auto& opArgs = *op->mutable_arg();
auto remove_arg = [](decltype(opArgs)& args, std::string& name) {
for (auto it = args.begin(); it != args.end(); it++) {
if (it->name() == name) {
args.erase(it);
return true;
}
}
return false;
};
while (remove_arg(opArgs, argName))
;
}
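// Illustrative sketch only (kept in a comment, not compiled): removeArg rescans
// the argument list from the start after every erase. The same effect can be
// obtained in a single pass with the erase-remove idiom. `Arg` and
// `removeAllArgs` are hypothetical; a plain std::vector stands in for the
// protobuf repeated field, and whether the idiom applies unchanged to
// RepeatedPtrField is not verified here.
/*
```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for caffe2::Argument: a name plus an integer value.
struct Arg {
  std::string name;
  long i;
};

// Removes every argument named `name` in one pass and returns how many were
// removed.
std::size_t removeAllArgs(std::vector<Arg>& args, const std::string& name) {
  auto newEnd = std::remove_if(args.begin(), args.end(),
                               [&name](const Arg& a) { return a.name == name; });
  const auto removed = static_cast<std::size_t>(args.end() - newEnd);
  args.erase(newEnd, args.end());
  return removed;
}
```
*/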
void moveOpArg(
caffe2::Workspace* ws,
std::string argName,
repr::NeuralNetOperator* srcOp,
repr::NeuralNetOperator* dstOp) {
if (argName.empty() || srcOp == nullptr || dstOp == nullptr || srcOp == dstOp)
return;
removeArg(*dstOp, argName);
auto& src = getOpDef(*srcOp);
auto& src_args = src.arg();
auto src_it = src_args.begin();
for (; src_it != src_args.end(); src_it++) {
if (src_it->name() == argName)
break;
}
if (src_it == src_args.end())
return;
auto* dst = getMutableOpDef(*dstOp);
auto* arg = dst->add_arg();
*arg = *src_it;
arg->set_name(argName);
}
bool removeStopGradientForInference(repr::NNModule* nn, caffe2::Workspace* ws) {
auto allNodes = nn->dataFlow.getMutableNodes();
// NOLINTNEXTLINE(modernize-loop-convert,clang-diagnostic-sign-compare)
for (int i = 0; i < allNodes.size(); ++i) {
auto node = allNodes[i];
if (!isOpType(node, "StopGradient")) {
continue;
}
auto stopGradInput = repr::nn::getInputs(node).front();
auto stopGradOutput = repr::nn::getOutputs(node).front();
auto inputName = repr::nn::get<repr::Tensor>(stopGradInput)->getName();
auto outputName = repr::nn::get<repr::Tensor>(stopGradOutput)->getName();
if (inputName == outputName) {
nn->dataFlow.replaceNode(stopGradOutput, stopGradInput);
nn->dataFlow.deleteNode(node);
return true;
}
}
return false;
}
bool fuseConvBNAndAffCh(repr::NNModule* nn, caffe2::Workspace* ws) {
for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
bool no_bias = false;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
repr::NNGraph::NodeRef convNode;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
repr::Conv* conv;
std::tie(conv, convNode) = node_pair;
if (!isOnIdeepDevice(*conv)) {
LOG(WARNING) << "Not an IDEEP operator";
continue;
}
const auto& convOp = getOpDef(*conv);
if (convOp.type() == "ConvFusion") {
continue;
}
auto convOutput = repr::nn::getOutputs(convNode).front();
auto consumers = repr::nn::getConsumers(convOutput);
// Fuse only if convOutput is consumed solely by the BN/AffineChannel op.
if (consumers.size() != 1) {
continue;
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
bool isBN;
auto consumer = consumers.front();
if (repr::nn::is<repr::BatchNormalization>(consumer)) {
isBN = true;
} else if (isOpType(consumer, "AffineChannel")) {
isBN = false;
} else {
continue;
}
auto bnOrAffChNode = consumer;
auto bn =
isBN ? repr::nn::get<repr::BatchNormalization>(bnOrAffChNode) : nullptr;
auto bnOrAffChOutput = repr::nn::getOutputs(bnOrAffChNode).front();
auto convInputs = repr::nn::getInputs(convNode);
if (convInputs.size() < 2) {
LOG(WARNING) << "Invalid convolution input size";
continue;
}
auto bnOrAffChInputs = repr::nn::getInputs(bnOrAffChNode);
int numInputs = isBN ? 5 : 3;
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
if (bnOrAffChInputs.size() < numInputs) {
LOG(WARNING) << "Invalid input size: " << bnOrAffChInputs.size()
<< ", expect " << numInputs;
continue;
}
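// If the conv has no bias, reuse the BN/AffineChannel bias blob as the conv
// bias input; its values are overwritten with the fused bias below.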
if (convInputs.size() < 3) {
no_bias = true;
nn->dataFlow.createEdge(bnOrAffChInputs[2], convNode);
convInputs = repr::nn::getInputs(convNode);
}
#define EXPOSE_TENSOR_DATA(name, index, nodes, need_init) \
itensor* name = nullptr; \
itensor name##Tensor; \
float* name##Data = nullptr; \
if (need_init) { \
name = getMutableTensor<itensor>(getBlob(nodes[index], ws)); \
if (name == nullptr) { \
LOG(WARNING) << #name " is not an IDEEP tensor"; \
continue; \
} \
name##Tensor.resize(name->get_dims(), name->get_data_type()); \
name##Tensor.feed_from(*name); \
CAFFE_ENFORCE( \
name##Tensor.is_public_format(), #name " is not in a public format"); \
name##Data = static_cast<float*>(name##Tensor.get_data_handle()); \
}
EXPOSE_TENSOR_DATA(filter, 1, convInputs, true);
EXPOSE_TENSOR_DATA(biasConv, 2, convInputs, true);
EXPOSE_TENSOR_DATA(scale, 1, bnOrAffChInputs, true);
EXPOSE_TENSOR_DATA(biasBNOrAffCh, 2, bnOrAffChInputs, true);
EXPOSE_TENSOR_DATA(mean, 3, bnOrAffChInputs, isBN);
EXPOSE_TENSOR_DATA(variance, 4, bnOrAffChInputs, isBN);
#undef EXPOSE_TENSOR_DATA
// Assume a 4-D filter of shape M x {CHW or HWC}: the weights of each output
// channel are contiguous.
auto chwDim = filterTensor.get_dim(1) * filterTensor.get_dim(2) *
filterTensor.get_dim(3);
for (auto c = 0; c < filterTensor.get_dim(0); ++c) {
float mean_val = 0;
float variance_val = 1;
if (isBN) {
mean_val = meanData[c];
variance_val = std::sqrt(varianceData[c] + bn->getEpsilon());
}
float coeff = scaleData[c] / variance_val;
for (auto i = 0; i < chwDim; ++i) {
filterData[c * chwDim + i] *= coeff;
}
if (no_bias) {
biasConvData[c] = biasBNOrAffChData[c] - mean_val * coeff;
} else {
biasConvData[c] =
biasBNOrAffChData[c] + (biasConvData[c] - mean_val) * coeff;
}
}
filter->feed_from(filterTensor);
biasConv->feed_from(biasConvTensor);
nn->dataFlow.replaceNode(convOutput, bnOrAffChOutput);
nn->dataFlow.deleteNode(bnOrAffChNode);
nn->dataFlow.deleteNode(convOutput);
return true;
}
return false;
}
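// Illustrative sketch only (kept in a comment, not compiled): the arithmetic
// fuseConvBNAndAffCh applies per output channel c, on plain float arrays. With
// coeff = scale[c] / sqrt(var[c] + eps), the weights of channel c are scaled by
// coeff and the bias becomes bnBias[c] + (convBias[c] - mean[c]) * coeff. The
// no-bias branch is the same formula with convBias[c] = 0, and AffineChannel
// corresponds to mean = 0 and a denominator of 1. The helper names are
// hypothetical, and `linear` reduces the convolution to one dot product per
// output channel, which is enough to check the algebra.
/*
```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;

// One dot product per output channel: y[c] = b[c] + sum_i w[c*k + i] * x[i].
Vec linear(const Vec& w, const Vec& b, const Vec& x) {
  const std::size_t m = b.size();
  const std::size_t k = x.size();
  Vec y(m);
  for (std::size_t c = 0; c < m; ++c) {
    float acc = b[c];
    for (std::size_t i = 0; i < k; ++i) {
      acc += w[c * k + i] * x[i];
    }
    y[c] = acc;
  }
  return y;
}

// Inference-mode batch norm applied per channel.
Vec batchNorm(const Vec& x, const Vec& scale, const Vec& bias, const Vec& mean,
              const Vec& var, float eps) {
  Vec y(x.size());
  for (std::size_t c = 0; c < x.size(); ++c) {
    y[c] = scale[c] * (x[c] - mean[c]) / std::sqrt(var[c] + eps) + bias[c];
  }
  return y;
}

// Folds the batch norm into weights (M x K, one row per output channel) and
// the length-M bias.
void foldBatchNorm(Vec& w, Vec& b, const Vec& scale, const Vec& bnBias,
                   const Vec& mean, const Vec& var, float eps) {
  const std::size_t m = b.size();
  const std::size_t k = w.size() / m;
  for (std::size_t c = 0; c < m; ++c) {
    const float coeff = scale[c] / std::sqrt(var[c] + eps);
    for (std::size_t i = 0; i < k; ++i) {
      w[c * k + i] *= coeff;
    }
    b[c] = bnBias[c] + (b[c] - mean[c]) * coeff;
  }
}
```
*/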
bool fuseConvSum(repr::NNModule* nn, caffe2::Workspace* ws) {
CAFFE_ENFORCE(cpuinfo_initialize(), "failed to initialize cpuinfo");
// Assume the order of nodes from getMutableNodes conforms to
// the original topo order of operators
auto allNodes = nn->dataFlow.getMutableNodes();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
for (int i = allNodes.size() - 1; i > 0; i--) {
auto sumNode = allNodes[i];
if (!repr::nn::hasInputs(sumNode)) {
continue;
}
// [Caution] On the IDEEP device, only the element-wise Add operator is
// supported so far; it behaves exactly as an element-wise Sum, without
// scalar broadcast.
bool is_dnnlowp_sum = false;
if (isOpType(sumNode, "Int8Sum") || isOpType(sumNode, "Int8Add") ||
isOpType(sumNode, "Int8SumRelu") || isOpType(sumNode, "Int8AddRelu")) {
is_dnnlowp_sum = true;
} else if (!repr::nn::is<repr::Sum>(sumNode) && !isOpType(sumNode, "Add")) {
continue;
}
auto sum = repr::nn::get<repr::NeuralNetOperator>(sumNode);
if (!isOnIdeepDevice(*sum)) {
LOG(WARNING) << "Not an IDEEP operator";
continue;
}
auto sumInputs = repr::nn::getInputs(sumNode);
if (sumInputs.size() != 2) {
continue;
}
int sum_idx = i;
repr::NNGraph::NodeRef convNode = nullptr;
while (--i >= 0) {
if (repr::nn::is<repr::NeuralNetOperator>(allNodes[i])) {
// Find the nearest conv Op before Sum
if (repr::nn::is<repr::Conv>(allNodes[i]) ||
isOpType(allNodes[i], "Int8Conv")) {
convNode = allNodes[i];
break;
}
}
}
if (convNode == nullptr || isConvFusion(convNode, FUSION_MAX)) {
continue;
}
int conv_idx = i;
auto conv = repr::nn::get<repr::NeuralNetOperator>(convNode);
if (!isOnIdeepDevice(*conv)) {
LOG(WARNING) << "Not an IDEEP operator";
continue;
}
auto group = 1;
auto* convOp = getMutableOpDef(*conv);
for (const auto& arg : convOp->arg()) {
if (arg.name() == "group") {
group = arg.i();
break;
}
}
if (group > 1 && !cpuinfo_has_x86_avx512f()) {
LOG(WARNING) << "Conv+Sum fusion with grouped filters requires AVX-512F";
continue;
}
auto convOutput = repr::nn::getOutputs(convNode).front();
if (convOutput != sumInputs[0] && convOutput != sumInputs[1]) {
continue;
}
repr::NNGraph::NodeRef sumInputX =
(sumInputs[0] == convOutput ? sumInputs[1] : sumInputs[0]);
CAFFE_ENFORCE(sumInputX != nullptr, "Invalid sum inputs");
if (sumInputX->getInEdges().size() <= 0) {
continue;
}
auto preNode = repr::nn::getProducer(sumInputX);
if (preNode == nullptr || !repr::nn::is<repr::NeuralNetOperator>(preNode)) {
LOG(WARNING) << "Cannot fuse Conv and Sum";
continue;
}
int pre_idx = sum_idx - 1;
while (pre_idx >= 0) {
if (preNode == allNodes[pre_idx]) {
break;
}
pre_idx--;
}
bool should_fuse = true;
auto convInput = repr::nn::getInputs(convNode).front();
// NOLINTNEXTLINE(clang-diagnostic-sign-compare)
for (int idx = conv_idx + 1; idx < allNodes.size() - 1; ++idx) {
if (idx == sum_idx ||
!repr::nn::is<repr::NeuralNetOperator>(allNodes[idx])) {
continue;
}
auto checkNode = allNodes[idx];
auto checkInputs = repr::nn::getInputs(checkNode);
// The Conv output must not be used by any other op after the Conv node
// (except the fused Sum). The other Sum input (sumInputX) must not be used
// by any op after the Sum node, because the fused Sum output is written in
// place into sumInputX.
// NOLINTNEXTLINE(modernize-loop-convert)
for (size_t input_idx = 0; input_idx < checkInputs.size(); ++input_idx) {
if (convOutput == checkInputs[input_idx] ||
(idx > sum_idx && sumInputX == checkInputs[input_idx])) {
should_fuse = false;
break;
}
}
if (!should_fuse) {
break;
}
// Fusing Conv with Sum effectively moves the Conv op down to sit between
// preNode and Sum, so check that no op between Conv and preNode overwrites
// the Conv input buffer.
if (idx <= pre_idx) {
auto checkOutputs = repr::nn::getOutputs(checkNode);
// NOLINTNEXTLINE(modernize-loop-convert)
for (size_t output_idx = 0; output_idx < checkOutputs.size();
++output_idx) {
auto check_output_tensor =
repr::nn::get<repr::Tensor>(checkOutputs[output_idx]);
auto conv_input_tensor = repr::nn::get<repr::Tensor>(convInput);
if (conv_input_tensor->getName() == check_output_tensor->getName()) {
should_fuse = false;
break;
}
}
}
if (!should_fuse) {
break;
}
}
if (!should_fuse) {
continue;
}
nn->dataFlow.createEdge(sumInputX, convNode);
auto newOutputName = repr::nn::get<repr::Tensor>(sumInputX)->getName() +
"_fusion_fix_" + std::to_string(i);
auto newInputTensor = std::make_unique<repr::Tensor>(newOutputName);
auto newInput = nn->dataFlow.createNode(
unique_dyn_cast<repr::NeuralNetData>(newInputTensor));
nn->dataFlow.replaceNode(sumInputX, newInput);
nn->dataFlow.deleteNode(sumInputX);
auto newOutputTensor = std::make_unique<repr::Tensor>(newOutputName);
auto newOutput = nn->dataFlow.createNode(
unique_dyn_cast<repr::NeuralNetData>(newOutputTensor));
auto sumOutput = repr::nn::getOutputs(sumNode).front();
nn->dataFlow.replaceNode(sumOutput, newOutput);
nn->dataFlow.createEdge(convNode, newOutput);
if (!is_dnnlowp_sum) {
resetConvForFusion(convNode, FUSION_CONV_SUM);
} else {
moveOpArg(ws, "Y_scale", sum, conv);
moveOpArg(ws, "Y_zero_point", sum, conv);
if (isOpType(sumNode, "Int8Sum") || isOpType(sumNode, "Int8Add")) {
convOp->set_type("Int8ConvSum");
} else if (
isOpType(sumNode, "Int8SumRelu") ||
isOpType(sumNode, "Int8AddRelu")) {
convOp->set_type("Int8ConvSumRelu");
} else {
CAFFE_THROW("Unsupported operator in conv fusion");
}
}
nn->dataFlow.deleteNode(sumNode);
nn->dataFlow.deleteNode(sumOutput);
nn->dataFlow.deleteNode(convOutput);
return true;
}
return false;
}
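// Illustrative sketch only (kept in a comment, not compiled): why the fused
// Conv+Sum must run in place on the other Sum input (sumInputX). The fused op
// accumulates the convolution result into that buffer instead of writing a
// fresh output, so the buffer's original contents are gone afterwards; this is
// why fuseConvSum refuses to fuse when a later op still reads sumInputX. The
// 1-D "valid" cross-correlation and both helper names are hypothetical
// simplifications, not the IDEEP kernel.
/*
```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<float>;

// 1-D "valid" cross-correlation; assumes x.size() >= k.size().
Vec conv1d(const Vec& x, const Vec& k) {
  Vec y(x.size() - k.size() + 1, 0.0f);
  for (std::size_t o = 0; o < y.size(); ++o) {
    for (std::size_t j = 0; j < k.size(); ++j) {
      y[o] += x[o + j] * k[j];
    }
  }
  return y;
}

// Fused form: accumulates conv1d(x, k) into `sumInOut`, which holds the other
// Sum input on entry and the Sum result on exit.
void conv1dSumInPlace(const Vec& x, const Vec& k, Vec& sumInOut) {
  for (std::size_t o = 0; o + k.size() <= x.size(); ++o) {
    for (std::size_t j = 0; j < k.size(); ++j) {
      sumInOut[o] += x[o + j] * k[j];
    }
  }
}
```
*/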
bool fuseActivation(repr::NNModule* nn, caffe2::Workspace* ws) {
// Conv+Relu fusion
for (auto node_pair : repr::nn::dataIterator<repr::Conv>(nn->dataFlow)) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
repr::NNGraph::NodeRef conv_node;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
repr::Conv* conv;
std::tie(conv, conv_node) = node_pair;
// Check topological feasibility
auto conv_outputs = repr::nn::getOutputs(conv_node);
if (conv_outputs.size() != 1) {
continue;
}
auto conv_output = conv_outputs.front();
auto consumers = repr::nn::getConsumers(conv_output);
if (consumers.size() != 1) {
continue;
}
if (!repr::nn::is<repr::Relu>(consumers.front())) {
continue;
}
auto relu_node = consumers.front();
auto relu_outputs = repr::nn::getOutputs(relu_node);
if (relu_outputs.size() != 1) {
continue;
}
// Check feasibility with application specific logic
if (!isOnIdeepDevice(*conv)) {
continue;
}
// Ready to fuse
auto relu_output = relu_outputs.front();
auto output_tensor = repr::nn::get<repr::Tensor>(relu_output);
auto output_node = relu_output;
auto input_tensor =
repr::nn::get<repr::Tensor>(repr::nn::getInputs(conv_node).front());
if (isConvFusion(conv_node, FUSION_CONV_SUM)) {
nn->dataFlow.replaceNode(relu_output, conv_output);
nn->dataFlow.deleteNode(relu_node);
nn->dataFlow.deleteNode(relu_output);
} else {
// Conv cannot be in-place
if (output_tensor->getName() != input_tensor->getName()) {
nn->dataFlow.replaceNode(conv_output, relu_output);
nn->dataFlow.deleteNode(relu_node);
nn->dataFlow.deleteNode(conv_output);
} else {
nn->dataFlow.replaceNode(relu_output, conv_output);
output_tensor = repr::nn::get<repr::Tensor>(conv_output);
output_node = conv_output;
nn->dataFlow.deleteNode(relu_node);
nn->dataFlow.deleteNode(relu_output);
}
// We may have accidentally made the next op in-place
// In future iterations of transformations this won't be an issue,
// but current caffe2 predictor usage requires things like
// external_input and output to be unchanged.
bool rectify_inplace = false;
for (auto& consumer : repr::nn::getConsumers(output_node)) {
for (auto& consumer_output : repr::nn::getOutputs(consumer)) {
auto co_name =
repr::nn::get<repr::Tensor>(consumer_output)->getName();
if (co_name == output_tensor->getName()) {
rectify_inplace = true;
}
}
}
if (rectify_inplace) {
auto new_output = nn->dataFlow.createNode(make_unique<repr::Tensor>(
output_tensor->getName() + "_fusion_fix"));
nn->dataFlow.replaceNode(output_node, new_output);
}
}
resetConvForFusion(conv_node, FUSION_CONV_RELU);
return true;
}
return false;
}
bool enforceFusionInplace(repr::NNModule* nn, caffe2::Workspace* ws) {
// For Conv+Sum and Conv+Sum+ReLU fusions, the last input and the output must
// be in place (share one blob). To enforce this, re-check the whole graph and
// correct the fused Conv ops.
auto allNodes = nn->dataFlow.getMutableNodes();
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
for (int i = allNodes.size() - 1; i > 0; i--) {
auto convNode = allNodes[i];
if (convNode == nullptr ||
!repr::nn::is<repr::NeuralNetOperator>(convNode)) {
continue;
}
auto conv = repr::nn::get<repr::NeuralNetOperator>(convNode);
if (!isOnIdeepDevice(*conv)) {
LOG(WARNING) << "Not an IDEEP operator";
continue;
}
if (repr::nn::is<repr::Conv>(convNode)) {
if (!isConvFusion(convNode, FUSION_CONV_SUM) &&
!isConvFusion(convNode, FUSION_CONV_SUM_RELU))
continue;
} else if (
!isOpType(convNode, "Int8ConvSum") &&
!isOpType(convNode, "Int8ConvSumRelu")) {
continue;
}
auto convInput = repr::nn::getInputs(convNode).back();
auto inputName = repr::nn::get<repr::Tensor>(convInput)->getName();
auto convOutput = repr::nn::getOutputs(convNode).front();
auto outputName = repr::nn::get<repr::Tensor>(convOutput)->getName();
if (inputName == outputName) {
continue;
}
auto consumer = repr::nn::getConsumers(convInput).back();
if (consumer != convNode) {
LOG(ERROR) << "Cannot enforce in-place input/output for fusion";
return false;
}
auto newOutputTensor = std::make_unique<repr::Tensor>(inputName);
auto newOutput = nn->dataFlow.createNode(
unique_dyn_cast<repr::NeuralNetData>(newOutputTensor));
nn->dataFlow.replaceNode(convOutput, newOutput);
nn->dataFlow.deleteNode(convOutput);
return true;
}
return false;
}
bool fuseOrderSwitchToQuantizeOp(repr::NNModule* nn, caffe2::Workspace* ws) {
// In an INT8 module, a quantize/dequantize op always appears together with
// a corresponding order switch op, which switches between the INT8
// computation domain and the others.
// Here we assume they always appear in one of these combinations and orders:
// NCHW2NHWC followed by Int8Quantize, or Int8Dequantize followed by NHWC2NCHW.
// On iDEEP, the order switch op can be fused into the quantize/dequantize op
// to improve the module's performance.
auto allNodes = nn->dataFlow.getMutableNodes();
// NOLINTNEXTLINE(modernize-loop-convert,clang-diagnostic-sign-compare)
for (int i = 0; i < allNodes.size(); ++i) {
auto osNode = allNodes[i];
if (osNode == nullptr || !repr::nn::is<repr::NeuralNetOperator>(osNode)) {
continue;
}
if (isOpType(osNode, "NCHW2NHWC")) {
auto output = repr::nn::getOutputs(osNode).front();
auto consumers = repr::nn::getConsumers(output);
if (consumers.size() != 1) {
continue;
}
auto seqNode = consumers.front();
if (!isOpType(seqNode, "Int8Quantize")) {
continue;
}
auto seq = repr::nn::get<repr::NeuralNetOperator>(seqNode);
removeArg(*seq, "output_order");
auto* seqOp = getMutableOpDef(*seq);
auto* arg = seqOp->add_arg();
arg->set_name("output_order");
arg->set_i(static_cast<int64_t>(iformat::nhwc));
auto input = repr::nn::getInputs(osNode).front();
nn->dataFlow.replaceNode(output, input);
nn->dataFlow.deleteNode(osNode);
nn->dataFlow.deleteNode(output);
return true;
} else if (isOpType(osNode, "NHWC2NCHW")) {
auto input = repr::nn::getInputs(osNode).front();
if (input->getInEdges().size() <= 0) {
continue;
}
auto preNode = repr::nn::getProducer(input);
if (!isOpType(preNode, "Int8Dequantize")) {
continue;
}
auto pre = repr::nn::get<repr::NeuralNetOperator>(preNode);
removeArg(*pre, "output_order");
auto* preOp = getMutableOpDef(*pre);
auto* arg = preOp->add_arg();
arg->set_name("output_order");
arg->set_i(static_cast<int64_t>(iformat::nchw));
auto output = repr::nn::getOutputs(osNode).front();
nn->dataFlow.replaceNode(input, output);
nn->dataFlow.deleteNode(osNode);
nn->dataFlow.deleteNode(input);
return true;
}
}
return false;
}
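// Illustrative sketch only (kept in a comment, not compiled): the reasoning
// behind folding the order switch into the quantize op. NCHW2NHWC is a pure
// index permutation, and quantization as sketched here is element-wise, so
// permuting before or after quantizing gives the same result. `nchwToNhwc` and
// `quantize` are hypothetical helpers; the affine uint8 scheme
// (round(x / scale) + zero point, clamped to [0, 255]) is an assumption, not
// IDEEP's quantizer. Requires C++17 for std::clamp.
/*
```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reorders a dense NCHW buffer into NHWC: element (n, c, h, w) moves from
// ((n*C + c)*H + h)*W + w to ((n*H + h)*W + w)*C + c.
template <typename T>
std::vector<T> nchwToNhwc(const std::vector<T>& src, std::size_t N,
                          std::size_t C, std::size_t H, std::size_t W) {
  std::vector<T> dst(src.size());
  for (std::size_t n = 0; n < N; ++n)
    for (std::size_t c = 0; c < C; ++c)
      for (std::size_t h = 0; h < H; ++h)
        for (std::size_t w = 0; w < W; ++w)
          dst[((n * H + h) * W + w) * C + c] =
              src[((n * C + c) * H + h) * W + w];
  return dst;
}

// Hypothetical per-tensor affine uint8 quantization, applied element-wise.
std::vector<std::uint8_t> quantize(const std::vector<float>& x, float scale,
                                   long zeroPoint) {
  std::vector<std::uint8_t> q(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const long v = std::lround(x[i] / scale) + zeroPoint;
    q[i] = static_cast<std::uint8_t>(std::clamp(v, 0L, 255L));
  }
  return q;
}
```
*/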
bool fusePreConvertOp(repr::NNModule* nn, caffe2::Workspace* ws) {
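// The ops in op_list can handle a preceding format or data type conversion
// themselves, so that conversion op can be dropped:
// 1. Int8Sum falls back to FP32 in the current implementation, so it can
//    handle inputs with different formats and data types.
// 2. FC can convert the input format and data type by itself.
// 3. The fallback wrapper can handle format and data type conversion.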
static vector<string> op_list = {
"FC",
"Python",
"Softmax",
"Sigmoid",
"RoIAlign",
"UpsampleNearest",
"BatchPermutation",
"Int8Sum",
"Int8SumRelu",
};
auto allNodes = nn->dataFlow.getMutableNodes();
// NOLINTNEXTLINE(modernize-loop-convert,clang-diagnostic-sign-compare)
for (int i = 0; i < allNodes.size(); ++i) {
auto opNode = allNodes[i];
if (opNode == nullptr || !repr::nn::is<repr::NeuralNetOperator>(opNode)) {
continue;
}
if (!isOpType(opNode, "NCHW2NHWC") && !isOpType(opNode, "NHWC2NCHW") &&
!isOpType(opNode, "Int8Quantize") &&
!isOpType(opNode, "Int8Dequantize")) {
continue;
}
auto op = repr::nn::get<repr::NeuralNetOperator>(opNode);
if (!isOnIdeepDevice(*op)) {
LOG(WARNING) << "Not an IDEEP operator";
continue;
}
auto output = repr::nn::getOutputs(opNode).front();
auto consumers = repr::nn::getConsumers(output);
if (consumers.size() != 1) {
continue;
}
bool is_op_found = false;
auto seqNode = consumers.front();
// NOLINTNEXTLINE(modernize-loop-convert,clang-diagnostic-sign-compare)
for (int j = 0; j < op_list.size(); j++) {
if (isOpType(seqNode, op_list[j])) {
is_op_found = true;
break;
}
}
if (!is_op_found) {
continue;
}
auto seqOp = repr::nn::get<repr::NeuralNetOperator>(seqNode);
if (!isOnIdeepDevice(*seqOp)) {
LOG(WARNING) << "Not an IDEEP operator";
continue;
}
auto input = repr::nn::getInputs(opNode).front();
if (isOpType(opNode, "Int8Dequantize") &&
repr::nn::hasSingleOutputAndConsumer(opNode)) {
auto preNode = repr::nn::getProducer(input);
if (isOpType(preNode, "Int8FC") &&
repr::nn::hasSingleOutputAndConsumer(preNode)) {
auto predOp = repr::nn::get<repr::NeuralNetOperator>(preNode);
removeArg(*predOp, "Y_scale");
removeArg(*predOp, "Y_zero_point");
}
}
nn->dataFlow.replaceNode(output, input);
nn->dataFlow.deleteNode(opNode);
nn->dataFlow.deleteNode(output);
return true;
}
return false;
}
void setPoolingInferenceMode(repr::NNModule* nn) {
auto setTrainingMode = [](repr::NeuralNetOperator& pool) {
if (!isOnIdeepDevice(pool)) {
LOG(WARNING) << "Not an IDEEP operator";
return;
}
auto* op = getMutableOpDef(pool);
bool found_training_mode = false;
for (auto& arg : *op->mutable_arg()) {
if (arg.name() == "training_mode") {
arg.set_i(0);
found_training_mode = true;
break;
}
}
if (!found_training_mode) {
auto* arg = op->add_arg();
arg->set_name("training_mode");
arg->set_i(0);
}
};
auto allNodes = nn->dataFlow.getMutableNodes();
// NOLINTNEXTLINE(modernize-loop-convert,clang-diagnostic-sign-compare)
for (int i = 0; i < allNodes.size(); ++i) {
auto poolNode = allNodes[i];
if (poolNode == nullptr ||
!repr::nn::is<repr::NeuralNetOperator>(poolNode)) {
continue;
}
if (isOpType(poolNode, "FC") || isOpType(poolNode, "Conv") ||
isOpType(poolNode, "ConvFusion") || isOpType(poolNode, "MaxPool") ||
isOpType(poolNode, "AveragePool") || isOpType(poolNode, "Int8FC") ||
isOpType(poolNode, "Int8Conv") || isOpType(poolNode, "Int8ConvRelu") ||
isOpType(poolNode, "Int8ConvSum") ||
isOpType(poolNode, "Int8ConvSumRelu") ||
isOpType(poolNode, "Int8MaxPool") ||
isOpType(poolNode, "Int8AveragePool")) {
auto pool = repr::nn::get<repr::NeuralNetOperator>(poolNode);
setTrainingMode(*pool);
}
}
}
// Pre-convert filters to the format the primitives expect, to avoid
// repeated format conversions (reorders) during computation.
void preConvertFiltersFormat(repr::NNModule* nn, caffe2::Workspace* ws) {
for (auto& node : nn->dataFlow.getMutableNodes()) {
if (!repr::nn::is<repr::ConvTranspose>(node) &&
!repr::nn::is<repr::Conv>(node) && !repr::nn::is<repr::FC>(node)) {
continue;
}
auto* nnOp = repr::nn::get<repr::NeuralNetOperator>(node);
if (!isOnIdeepDevice(*nnOp)) {
LOG(INFO) << "Not an IDEEP operator";
continue;
}
auto inputs = repr::nn::getInputs(node);
if (inputs.size() < 2) {
LOG(WARNING) << "Invalid input size";
continue;
}
auto* filterBlob = getBlob(inputs[1], ws);
auto* filter = getMutableTensor<itensor>(filterBlob);
if (filter == nullptr) {
continue;
}
itensor::descriptor expectedDesc;
if (repr::nn::is<repr::ConvTranspose>(node)) {
if (filter->get_desc().is_iohw())
continue;
auto convTranspose = repr::nn::get<repr::ConvTranspose>(node);
auto initValue = [](vector<int>& v, vector<int> i) {
if (v.empty())
v = i;
};
auto strides = convTranspose->getStrides();
initValue(strides, {1, 1});
auto pads = convTranspose->getPads();
initValue(pads, {0, 0, 0, 0});
auto dataType = filter->get_data_type();
ideep::tensor::dims filter_dims_mkldnn{filter->get_dim(1),
filter->get_dim(0),
filter->get_dim(2),
filter->get_dim(3)};
expectedDesc =
ideep::convolution_transpose_forward::expected_weights_desc(
filter_dims_mkldnn,
dataType,
{strides.begin(), strides.end()},
{pads[0], pads[1]},
{pads[2], pads[3]});
if (filter->get_descriptor() != expectedDesc) {
itensor newFilter;
newFilter.init(expectedDesc);
newFilter.feed_from(*filter);
filterBlob->Reset<itensor>(new itensor(std::move(newFilter)));
}
} else if (repr::nn::is<repr::Conv>(node)) {
auto conv = repr::nn::get<repr::Conv>(node);
auto initValue = [](vector<int>& v, vector<int> i) {
if (v.empty())
v = i;
};
auto strides = conv->getStrides();
initValue(strides, {1, 1});
auto pads = conv->getPads();
initValue(pads, {0, 0, 0, 0});
auto dilations = conv->getDilations();
initValue(dilations, {1, 1});
auto* op = getMutableOpDef(*conv);
auto convAlgorithm = ialgo::convolution_direct;
for (auto& arg : *op->mutable_arg()) {
if ((arg.name() == "conv_algorithm") &&
(arg.i() == CONV_ALGORITHM_WINOGRAD)) {
convAlgorithm = ialgo::convolution_winograd;
}
}
expectedDesc = ideep::convolution_forward::expected_weights_desc(
filter->get_dims(),
filter->get_data_type(),
{strides.begin(), strides.end()},
{pads[0], pads[1]},
{pads[2], pads[3]},
{dilations.begin(), dilations.end()},
conv->getGroup(),
convAlgorithm);
if (filter->get_descriptor() != expectedDesc) {
itensor newFilter;
newFilter.init(expectedDesc);
newFilter.feed_from(*filter);
filterBlob->Reset<itensor>(new itensor(std::move(newFilter)));
}
// convert weights for FC
} else if (repr::nn::is<repr::FC>(node)) {
auto fc = repr::nn::get<repr::FC>(node);
auto axis_w = fc->getAxisW();
if (axis_w != 1) {
auto f_dims = filter->get_dims();
// The initial value sets std::accumulate's accumulator type; use dim_t
// rather than int so large products are not truncated.
auto f_dim0 = std::accumulate(
f_dims.begin(),
f_dims.begin() + axis_w,
itensor::dim_t{1},
// NOLINTNEXTLINE(modernize-use-transparent-functors)
std::multiplies<itensor::dim_t>());
auto f_dim1 = std::accumulate(
f_dims.begin() + axis_w,
f_dims.end(),
itensor::dim_t{1},
// NOLINTNEXTLINE(modernize-use-transparent-functors)
std::multiplies<itensor::dim_t>());
filter->reshape({f_dim0, f_dim1});
}
expectedDesc = ideep::inner_product_forward::expected_weights_desc(
filter->get_dims());
if (filter->get_descriptor() != expectedDesc) {
itensor newFilter;
newFilter.init(expectedDesc);
newFilter.feed_from(*filter);
filterBlob->Reset<itensor>(new itensor(std::move(newFilter)));
}
}
}
}
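// Illustrative sketch only (kept in a comment, not compiled): the 2-D shape the
// FC branch computes when axis_w != 1, i.e. the product of the dims before
// axis_w and the product of the remaining dims. The initial value passed to
// std::accumulate fixes the accumulator type, so it must be a 64-bit integer;
// a plain `1` would make it int. `flattenAt` is a hypothetical helper that uses
// std::int64_t in place of itensor::dim_t.
/*
```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// Collapses dims [0, axis) and [axis, end) into a 2-D shape, accumulating the
// products in 64 bits.
std::pair<std::int64_t, std::int64_t> flattenAt(
    const std::vector<std::int64_t>& dims, std::size_t axis) {
  const auto mid = dims.begin() + static_cast<std::ptrdiff_t>(axis);
  const std::int64_t rows = std::accumulate(
      dims.begin(), mid, std::int64_t{1}, std::multiplies<std::int64_t>());
  const std::int64_t cols = std::accumulate(
      mid, dims.end(), std::int64_t{1}, std::multiplies<std::int64_t>());
  return {rows, cols};
}
```
*/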
// Fusers for ideep to parse the graph and apply operator fusion
using Fuser = bool (*)(repr::NNModule* nn, caffe2::Workspace* ws);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables,cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static Fuser fusers[] = {
removeStopGradientForInference,
fuseConvBNAndAffCh,
fuseConvSum,
fuseActivation,
enforceFusionInplace,
fuseOrderSwitchToQuantizeOp,
fusePreConvertOp,
};
void OptimizeForMkldnn(
repr::NNModule* nn,
caffe2::Workspace* ws,
bool training_mode) {
if (training_mode) {
preConvertFiltersFormat(nn, ws);
return;
}
for (auto fuser : fusers) {
while (fuser(nn, ws)) {
}
}
setPoolingInferenceMode(nn);
}
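// Illustrative sketch only (kept in a comment, not compiled): how
// OptimizeForMkldnn drives the fusers. Each fuser rewrites at most one site per
// call and returns true if it changed the graph, so it is re-run until it
// reports no change before the next fuser starts; the order of `fusers`
// therefore matters (here, dropping StopGradient first exposes a Conv/Relu
// pair). The toy graph is a linear list of op names, and every name below is
// hypothetical.
/*
```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

using Ops = std::vector<std::string>;
// Same contract as the Fuser alias: rewrite at most one site, return true if
// anything changed.
using ToyFuser = bool (*)(Ops&);

bool dropStopGradient(Ops& ops) {
  for (auto it = ops.begin(); it != ops.end(); ++it) {
    if (*it == "StopGradient") {
      ops.erase(it);
      return true;
    }
  }
  return false;
}

bool fuseConvRelu(Ops& ops) {
  for (std::size_t i = 0; i + 1 < ops.size(); ++i) {
    if (ops[i] == "Conv" && ops[i + 1] == "Relu") {
      ops[i] = "ConvRelu";
      ops.erase(ops.begin() + static_cast<std::ptrdiff_t>(i + 1));
      return true;
    }
  }
  return false;
}

// Runs each pass to a fixed point, in order, like the loop over `fusers`.
void runToFixedPoint(Ops& ops) {
  static const ToyFuser passes[] = {dropStopGradient, fuseConvRelu};
  for (auto pass : passes) {
    while (pass(ops)) {
    }
  }
}
```
*/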
#endif // USE_MKLDNN
} // namespace opt
} // namespace caffe2