#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_linear_bn.h>
#include <torch/csrc/jit/passes/frozen_linear_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>
#endif
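
// Folds frozen aten::batch_norm nodes into the aten::linear nodes that feed
// them: the BN statistics and affine parameters are baked into new linear
// weight/bias constants, leaving the BN node dead.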
namespace torch::jit {

namespace {

using Tensor = at::Tensor;

bool supportedLinearNode(Node* n) {
  return n->kind() == aten::linear;
}
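
// Walks `b` (recursing into nested blocks) and folds every aten::batch_norm
// whose input comes from a supported linear node with constant parameters.
// Returns true if the graph was modified.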
bool FoldFrozenLinearBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenLinearBatchnorm(block);
    }
    if (n->kind() == aten::batch_norm &&
        supportedLinearNode(n->inputs().at(0)->node())) {
      auto linear = n->inputs().at(0)->node();
      auto bn = n;
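      // Folding is only sound when every linear/BN parameter is a constant,
      // i.e. the graph has been frozen; bail out otherwise.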
      if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
        continue;
      }
      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");
      // Check that running_mean and running_var have values; if they are
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() &&
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }
      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();
      int64_t linear_out_features = linear_w.size(0);
      int64_t bn_num_features = bn_rm.size(0);
      // Linear-BN fusion must preserve the shapes of the linear weight/bias.
      // For that, the channel dim of bn has to be broadcastable with the
      // last dim of linear, because bn operates over the channel dim of
      // (N, C_in, H, W) while linear operates over the last dim of (*, H_in).
      // To be broadcastable, the number of features in bn and the number of
      // output features from linear must satisfy one of:
      //   1. they are equal, or
      //   2. the number of features in bn is 1.
      // Otherwise, skip the folding path.
      if (!(linear_out_features == bn_num_features || bn_num_features == 1)) {
        continue;
      }
      // implementation taken from torch/nn/utils/fusion.py
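      // In effect (the arithmetic itself lives in
      // computeUpdatedLinearWeightAndBias, declared in fold_linear_bn.h):
      //   scale   = bn_w / sqrt(bn_rv + bn_eps)
      //   w_fused = linear_w * scale.unsqueeze(-1)
      //   b_fused = (linear_b - bn_rm) * scale + bn_b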
      Tensor linear_b;
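      // If linear has no bias, synthesize a zero bias. On CUDA, half and
      // bfloat16 weights paired with float running stats get the weight's
      // dtype instead, so the fused bias ends up matching the weight.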
      if (linear->namedInput("bias")->type() == NoneType::get()) {
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = linear_w.scalar_type();
        at::DeviceType weight_device = linear_w.device().type();
        if (weight_device == at::kCUDA &&
            (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        linear_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
      }
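      // Missing BN affine parameters (affine=False) default to weight = 1
      // and bias = 0, matching aten::batch_norm's semantics.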
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (bn->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }
      LinearBNParameters params;
      params.linear_w = linear_w;
      params.linear_b = linear_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
      std::tuple<Tensor, Tensor> out =
          computeUpdatedLinearWeightAndBias(params);
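      // Insert the fused weight/bias as new constants right before the
      // linear node, wire them in, and redirect every use of the BN output
      // to the linear output. The BN node itself becomes dead and is removed
      // by the EliminateDeadCode pass run by the graph-level entry point.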
      WithInsertPoint guard(linear);
      auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto linear_w_value = linear->namedInput("weight");
      auto linear_b_value = linear->namedInput("bias");
      fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
      fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");
      linear->replaceInputWith(linear_w_value, fused_linear_w);
      linear->replaceInputWith(linear_b_value, fused_linear_b);
      bn->output()->replaceAllUsesWith(linear->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

} // namespace
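
// A minimal usage sketch (hypothetical call site; assumes the module has
// already been frozen, e.g. with torch::jit::freeze, so that the linear
// parameters are constants):
//
//   torch::jit::Module module = torch::jit::load("model.pt");
//   module.eval();
//   torch::jit::Module frozen = torch::jit::freeze(module);
//   std::shared_ptr<torch::jit::Graph> graph =
//       frozen.get_method("forward").graph();
//   bool changed = torch::jit::FoldFrozenLinearBatchnorm(graph);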

bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

} // namespace torch::jit