1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
|
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <ATen/core/symbol.h>
namespace torch {
namespace jit {
namespace {
// Alias-analysis kind passed to every Operator registered in this file; the
// aliasing relationship is taken from the `Tensor(a)` annotations written in
// each operator's schema string.
c10::AliasAnalysisKind aliasAnalysisFromSchema() {
  constexpr auto kind = c10::AliasAnalysisKind::FROM_SCHEMA;
  return kind;
}
} // namespace
// Decides statically whether an optional tensor argument is defined.
//
// Returns:
//   true          - the value's type is a Tensor subtype (it cannot be None);
//   false         - the producing node is known to yield None;
//   c10::nullopt  - neither can be proven (e.g. an Optional[Tensor] value).
c10::optional<bool> isDefined(Value* tensor) {
  const bool statically_tensor =
      tensor->type()->isSubtypeOf(*TensorType::get());
  if (statically_tensor) {
    return true;
  }
  const bool statically_none = tensor->node()->mustBeNone();
  if (statically_none) {
    return false;
  }
  return c10::nullopt;
}
// Returns true when `normalize_op` is an aten::batch_norm / aten::layer_norm
// node that DecomposeOps is allowed to rewrite. All of the following must hold:
//   * its `input` is statically a Tensor whose device is known to be CUDA
//     (decomposition is currently done for GPU only);
//   * the node matches one of the two supported schemas;
//   * it can be determined statically whether `weight` and `bias` are defined.
//     Otherwise decomposing is pointless, because the result would not fuse
//     into a single kernel.
bool isDecomposableNorm(Node* normalize_op) {
  static const OperatorSet decomposable_normalization_ops = {
      "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
      "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor",
  };

  Value* norm_input = normalize_op->namedInput(attr::input);
  if (!norm_input->type()->isSubtypeOf(*TensorType::get())) {
    return false;
  }

  const auto maybe_device =
      norm_input->type()->expectRef<TensorType>().device();
  const bool known_cuda = maybe_device.has_value() && maybe_device->is_cuda();
  if (!known_cuda) {
    return false;
  }

  if (!normalize_op->isMemberOf(decomposable_normalization_ops)) {
    return false;
  }

  return isDefined(normalize_op->namedInput(attr::weight)).has_value() &&
      isDefined(normalize_op->namedInput(attr::bias)).has_value();
}
// Helper operators used by the TorchScript decompositions compiled in
// DecomposeOps(std::shared_ptr<Graph>&).
//
// aten::_ncf_unsqueeze(self, ndim):
//   Reshapes a 1-D tensor of length C into an `ndim`-dimensional tensor of
//   shape [1, C, 1, ..., 1], so per-channel values sit on dimension 1.
// aten::_ncf_view(self, input_shape, normalized_ndim):
//   Reshapes `self` to input_shape.size() dimensions. The leading
//   (input_ndim - normalized_ndim) sizes are copied from `input_shape`, and the
//   trailing `normalized_ndim` sizes are 1.
// The schemas declare each output as an alias (`Tensor(a)`) of `self`.
RegisterOperators reg_ops(
    {Operator(
         "aten::_ncf_unsqueeze(Tensor(a) self, int ndim) -> Tensor(a)",
         [](Stack& stack) {
           const int64_t ndim = pop(stack).toInt();
           auto self = pop(stack).toTensor();
           // Dimension 1 is written below, so at least two dims are needed.
           // Checking before building `sizes` also keeps a negative value
           // from being converted into a huge container size.
           TORCH_CHECK(
               ndim >= 2, "_ncf_unsqueeze expects ndim >= 2, but got ", ndim);
           AT_ASSERT(self.dim() == 1);
           c10::SmallVector<int64_t, 8> sizes(ndim, 1);
           sizes.at(1) = self.size(0);
           push(stack, self.reshape(sizes));
         },
         aliasAnalysisFromSchema()),
     Operator(
         "aten::_ncf_view(Tensor(a) self, int[] input_shape, int normalized_ndim) -> Tensor(a)",
         [](Stack& stack) {
           const int64_t normalized_ndim = pop(stack).toInt();
           auto input_shape = pop(stack).toIntList();
           auto self = pop(stack).toTensor();
           const int64_t input_ndim = input_shape.size();
           // A negative normalized_ndim would make the copy loop below run
           // past the end of `sizes`. One larger than the input rank has no
           // meaningful layout.
           TORCH_CHECK(
               normalized_ndim >= 0 && normalized_ndim <= input_ndim,
               "_ncf_view expects 0 <= normalized_ndim <= ",
               input_ndim,
               ", but got ",
               normalized_ndim);
           c10::SmallVector<int64_t, 8> sizes(input_ndim, 1);
           const int64_t kept_ndim = input_ndim - normalized_ndim;
           for (int64_t i = 0; i < kept_ndim; ++i) {
             sizes.at(i) = input_shape.get(i);
           }
           push(stack, self.reshape(sizes));
         },
         aliasAnalysisFromSchema())});
// Walks `block` and, recursively, every nested block. It replaces decomposable
// nodes with the TorchScript implementations compiled into `decompose_funcs`:
//   * aten::addmm with constant alpha == beta == 1 -> `addmm` (mm + add), so
//     it can go through the existing batchmm optimization;
//   * aten::batch_norm / aten::layer_norm accepted by isDecomposableNorm ->
//     `batch_norm` / `layer_norm`, followed by an explicit multiply by
//     `weight` and add of `bias` when those are statically defined.
// Returns true if any node in `block` or in any nested block was rewritten.
bool DecomposeOps(Block* block, CompilationUnit& decompose_funcs) {
  bool decomposed = false;
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto sub : it->blocks()) {
      // Propagate changes made inside nested blocks (e.g. If/Loop bodies) so
      // the caller re-runs its cleanup passes even when only a sub-block
      // was rewritten.
      if (DecomposeOps(sub, decompose_funcs)) {
        decomposed = true;
      }
    }
    if (it->matches(
            "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor",
            /*const_inputs=*/{attr::beta, attr::alpha})) {
      // For the case where we have an addmm where alpha and beta are Attributes
      // and both of those scalars are equal to 1.0, decompose this into an mm
      // followed by an add so that it can go through the existing optimization
      // (batchmm)
      if (it->get<at::Scalar>(attr::alpha)->toComplexDouble() != 1.0 ||
          it->get<at::Scalar>(attr::beta)->toComplexDouble() != 1.0) {
        continue;
      }
      decomposed = true;
      WithInsertPoint guard(*it);
      std::shared_ptr<Graph> d_graph =
          toGraphFunction(decompose_funcs.get_function("addmm")).graph();
      Value* new_output =
          insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
      // Set the output of the decomposed graph to have the same output type as
      // the original op otherwise the canonicalized graph will have TensorType
      // as the output of this node which is incorrect
      new_output->setType(it->output()->type());
      it->output()->replaceAllUsesWith(new_output);
      // destroyCurrent steps the iterator back, so the loop's ++it lands on
      // the node that followed the destroyed one.
      it.destroyCurrent();
    } else if (
        it->matches(
            "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor")) {
      if (!isDecomposableNorm(*it)) {
        continue;
      }
      decomposed = true;
      WithInsertPoint insert_guard{*it};
      Graph* graph = it->owningGraph();
      Value* input = it->namedInput(attr::input);
      Value* input_dim = graph->insert(aten::dim, {input});
      std::vector<Value*> inputs{
          input,
          it->namedInput(attr::running_mean),
          it->namedInput(attr::running_var),
          it->namedInput(attr::training),
          it->namedInput(attr::momentum),
          it->namedInput(attr::eps)};

      // inline the compiled decomposed batchnorm
      std::shared_ptr<Graph> d_graph =
          toGraphFunction(decompose_funcs.get_function("batch_norm")).graph();
      Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);

      // Post-process: apply the affine weight/bias. isDecomposableNorm
      // guaranteed that isDefined() has a value for both, so .value() is safe.
      Value* weight = it->namedInput(attr::weight);
      Value* bias = it->namedInput(attr::bias);
      if (isDefined(weight).value()) {
        Value* expanded_weight =
            graph->insert(aten::_ncf_unsqueeze, {weight, input_dim});
        new_output = graph->insert(aten::mul, {new_output, expanded_weight});
      }
      if (isDefined(bias).value()) {
        Value* expanded_bias =
            graph->insert(aten::_ncf_unsqueeze, {bias, input_dim});
        new_output = graph->insert(aten::add, {new_output, expanded_bias});
      }
      it->output()->replaceAllUsesWith(new_output);
      it.destroyCurrent();
    } else if (
        it->matches(
            "aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? weight, Tensor? bias, float eps, bool cudnn_enable) -> Tensor")) {
      if (!isDecomposableNorm(*it)) {
        continue;
      }
      decomposed = true;
      WithInsertPoint insert_guard{*it};
      Graph* graph = it->owningGraph();
      std::vector<Value*> inputs{
          it->namedInput(attr::input),
          it->namedInput(attr::normalized_shape),
          it->namedInput(attr::eps),
          it->namedInput(attr::cudnn_enable)};

      // inline the compiled decomposed layernorm
      std::shared_ptr<Graph> d_graph =
          toGraphFunction(decompose_funcs.get_function("layer_norm")).graph();
      Value* new_output = insertGraph(*graph, *d_graph, inputs).at(0);

      // Post-process: apply the elementwise affine weight/bias (same
      // has_value() guarantee as in the batch_norm branch).
      Value* weight = it->namedInput(attr::weight);
      Value* bias = it->namedInput(attr::bias);
      if (isDefined(weight).value()) {
        new_output = graph->insert(aten::mul, {new_output, weight});
      }
      if (isDefined(bias).value()) {
        new_output = graph->insert(aten::add, {new_output, bias});
      }
      it->output()->replaceAllUsesWith(new_output);
      it.destroyCurrent();
    }
  }
  return decomposed;
}
// Entry point of the pass: decomposes addmm / batch_norm / layer_norm nodes
// in `graph` (see DecomposeOps(Block*, CompilationUnit&)). When anything was
// rewritten, it re-runs input shape propagation, constant propagation and
// dead-code elimination over the whole graph.
void DecomposeOps(std::shared_ptr<Graph>& graph) {
  // TorchScript implementations of the decomposed ops. This is a
  // function-local static, so it is compiled on first use and then reused.
  // The script is Python-syntax and indentation-sensitive: function bodies
  // and if/else/for bodies must stay indented.
  static CompilationUnit decompose_funcs(R"SCRIPT(
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: number = 1.0, alpha: number = 1.0):
    return self + mat1.mm(mat2)

def batch_norm(input : Tensor, running_mean : Optional[Tensor], running_var : Optional[Tensor], training : bool, momentum : float, eps : float) -> Tensor:
    if training:
        norm_mean, norm_var = torch.batch_norm_update_stats(input, running_mean, running_var, momentum)
    else:
        norm_mean = torch._unwrap_optional(running_mean)
        norm_var = torch._unwrap_optional(running_var)
    norm_mean = torch._ncf_unsqueeze(norm_mean, input.dim())
    norm_var = torch._ncf_unsqueeze(norm_var, input.dim())
    norm_invstd = 1 / (torch.sqrt(norm_var + eps))
    return ((input - norm_mean) * norm_invstd)

def layer_norm(input : Tensor, normalized_shape : List[int], eps : float, cudnn_enable : bool) -> Tensor:
    input_ndim = input.dim()
    normalized_ndim = len(normalized_shape)
    n = 1
    for i in range(input_ndim - normalized_ndim):
        n *= input.size(i)
    input_reshape = input.contiguous().view(1, n, -1)
    mean, invstd = torch.batch_norm_stats(input_reshape, eps)
    input_shape = input.size()
    mean = torch._ncf_view(mean, input_shape, normalized_ndim)
    invstd = torch._ncf_view(invstd, input_shape, normalized_ndim)
    return (input - mean) * invstd
)SCRIPT");

  if (!DecomposeOps(graph->block(), decompose_funcs)) {
    // Nothing was rewritten, so the follow-up passes are skipped.
    return;
  }
  // The graph was decomposed: re-run these passes over the new nodes.
  PropagateInputShapes(graph);
  ConstantPropagation(graph);
  EliminateDeadCode(graph);
}
} // namespace jit
} // namespace torch
|