#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
#include <torch/csrc/jit/codegen/onednn/interface.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/operator_options.h>
namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void fuseGraph(std::shared_ptr<Graph>& g) {
  // Follow the process of the tensorexpr_fuser in profiling mode:
  // remove prim::profile nodes and embed the profile info directly in the
  // IR in value types, to avoid breaking the fusion patterns.
  // A shape guard is added after the LLGA optimization passes, and the tensor
  // type information is then wiped from the IR so that it's not accidentally
  // used by any other pass.

  // We rely on the shape specialization and the shape guard to ensure the
  // validity of the cached compilation in the kernel, so only profiling mode
  // is supported.
  // TODO: add a check on oneDNNFusionGroup to ensure allShapesAreKnown on the
  // nodes to fuse: torch/csrc/jit/passes/tensorexpr_fuser.cpp: allShapesAreKnown
  if (getProfilingMode()) {
    GRAPH_DUMP(
        "Before RemoveProfileNodesAndSpecializeTypes. Beginning of LLGA "
        "optimization pass",
        g);
    RemoveProfileNodesAndSpecializeTypes(g);

    GRAPH_DUMP(
        "After RemoveProfileNodesAndSpecializeTypes. Before mutation removal",
        g);
    // Functionalize the in-place ops that the LLGA fuser knows how to handle.
    RemoveTensorMutation(g, [](Node* nodeToFunctionalize) {
      static std::unordered_set<Symbol> supportedOps = {
          aten::add_,
          aten::mul_,
          aten::tanh_,
          aten::elu_,
          aten::relu_,
          aten::relu6_,
          aten::gelu_,
          aten::sqrt_,
          aten::sigmoid_,
          aten::hardtanh_,
          aten::abs_,
          aten::square_,
      };
      return supportedOps.count(nodeToFunctionalize->kind()) != 0;
    });
    RemoveListMutation(g);

    GRAPH_DUMP("After mutation removal. Before PrepareBinaryForLLGA", g);
    PrepareBinaryForLLGA(g);

    GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g);
    DeferSizeCheck(g);

    GRAPH_DUMP("After DeferSizeCheck. Before CreateLlgaSubgraphs", g);
    CreateLlgaSubgraphs(g);

    GRAPH_DUMP("After CreateLlgaSubgraphs. Before PropagateLayout", g);
    PropagateLayout(g);

    GRAPH_DUMP(
        "After PropagateLayout. Before prepareFusionGroupAndGuardOutputs", g);
    // Add a shape guard for profiling mode and wipe the tensor type
    // information from the IR.
    prepareFusionGroupAndGuardOutputs(g->block());

    GRAPH_DUMP(
        "After prepareFusionGroupAndGuardOutputs. Before "
        "RemoveTensorTypeSpecializations",
        g);
    RemoveTensorTypeSpecializations(g);

    GRAPH_DUMP(
        "After RemoveTensorTypeSpecializations. End of LLGA optimization pass",
        g);
  }
}
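
// A minimal sketch (not part of this file) of how fuseGraph could be wired
// into the JIT optimization pipeline via the pass manager that is already
// included above (torch/csrc/jit/passes/pass_manager.h). The variable name
// register_llga_fuse_graph is made up for illustration:
//
//   static RegisterPass register_llga_fuse_graph(
//       [](std::shared_ptr<Graph>& g) { fuseGraph(g); });
//
// Passes registered this way run late in the optimization pipeline, where the
// profile information consumed above should already be available; fuseGraph
// itself is a no-op outside profiling mode.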
} // namespace onednn
} // namespace fuser
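
// createLlgaKernel constructs one LlgaKernel per fusion node and captures it
// in the returned Operation, so the kernel (and the compilation it caches) is
// reused across invocations of the same prim::oneDNNFusionGroup node.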
Operation createLlgaKernel(const Node* node) {
  auto kernel = std::make_shared<fuser::onednn::LlgaKernel>(node);
  return [kernel](Stack* stack) {
    RECORD_FUNCTION(kernel->debugName(), std::vector<c10::IValue>());
    kernel->run(*stack);
    return 0;
  };
}
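
// Bind the kernel factory above to prim::oneDNNFusionGroup, the node kind
// produced by CreateLlgaSubgraphs. Like other fusion groups, this node has no
// schema to derive aliasing from, so alias analysis treats it as an internal
// special case.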
RegisterOperators oneDNNFusionGroupOp({
    torch::jit::Operator(
        prim::oneDNNFusionGroup,
        createLlgaKernel,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

// Currently, we convert some scalar inputs, such as the second argument of
// binary ops, to a 1D tensor. Other scalar inputs are prim::Constant nodes.
// But if we have any scalar inputs to guard in the future, some logic here
// would have to be changed.
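//
// For illustration only (hypothetical IR, not taken from this file): after
// PrepareBinaryForLLGA, a graph like
//   %y = aten::add(%x, %scalar, %alpha)
// carries its scalar operand as a 1D tensor input, so the guard below only
// has to type-check tensor inputs; anything that is not a tensor simply
// fails the guard.
//
// createLlgaGuardKernel implements the runtime check behind
// prim::oneDNNFusionGuard: it compares each guarded input on the stack
// against the profiled TensorType stored in the node's attr::types attribute
// and pushes a single boolean result, which the guard/fallback logic added by
// prepareFusionGroupAndGuardOutputs uses to choose between the fused and the
// unfused path.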
Operation createLlgaGuardKernel(const Node* node) {
  return [node](Stack* stack) {
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("Guarding node: ", node->kind().toQualString());
#endif
    std::vector<TypePtr> types = node->tys(attr::types);
    const auto num_inputs = types.size();
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("num_inputs to guard: ", num_inputs);
#endif
    for (size_t i = 0; i < num_inputs; i++) {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("checking input ", i);
#endif
      auto& input = peek(stack, i, num_inputs);
      const c10::TensorTypePtr& guard_tensor_type =
          types[i]->cast<TensorType>();

      if (!input.isTensor()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is not a tensor, return false");
#endif
        push(stack, IValue(false));
        return;
      }
      const at::Tensor& tensor = input.toTensor();

      // If an input tensor is an mkldnn tensor, it originated from an
      // upstream LLGA partition that has already passed the check on input
      // shapes. It is valid to continue here as long as the output shapes of
      // oneDNN graph partitions are determined by their input shapes.
      if (tensor.is_mkldnn()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is_mkldnn, continue");
#endif
        continue;
      }

      if (!guard_tensor_type->matchTensor(tensor)) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " check failed, return false");
#endif
        push(stack, IValue(false));
        return;
      }
    }
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("all check done, return true");
#endif
    push(stack, IValue(true));
    return;
  };
}
RegisterOperators oneDNNGuardOp({
    torch::jit::Operator(
        prim::oneDNNFusionGuard,
        createLlgaGuardKernel,
        AliasAnalysisKind::FROM_SCHEMA),
});
} // namespace jit
} // namespace torch