#include <aten/src/ATen/core/jit_type.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/shape_analysis.h>
namespace torch::jit::fuser::onednn {
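
// Returns true if v is a constant int or double whose value equals d.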
static bool compareConstValue(Value* v, double d) {
  auto ival = toIValue(v);
  return ival.has_value() &&
      ((ival->isInt() && static_cast<int>(ival->toInt()) == d) ||
       (ival->isDouble() && ival->toDouble() == d));
}
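
// Makes both operands of a binary op acceptable to oneDNN graph, which
// requires both inputs to have the same dtype:
//  - if the second input is a Python scalar, it is materialized as a 1-D
//    tensor with the dtype of the first input;
//  - if both inputs are tensors with different dtypes, an aten::to cast to
//    the promoted dtype is inserted on the input that does not already
//    have it.
// Illustrative IR sketch (not an exact graph dump) for the scalar case:
//   %out = aten::add(%x, %scalar, %alpha)
//     -->
//   %t = aten::as_tensor(%scalar, dtype=...)
//   %t1 = aten::unsqueeze(%t, 0)
//   %out = aten::add(%x, %t1, %alpha)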
static void handleBinaryOpInputs(Node* node) {
  // We do not handle binary ops with two scalar inputs,
  // and we assume the scalar is always in the second place.
  if (node->input(0)->type()->isSubtypeOf(TensorType::get())) {
    auto dtypeOfFirstInput =
        node->input(0)->type()->cast<TensorType>()->scalarType().value();
    if (node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
        node->input(1)->type()->isSubtypeOf(IntType::get())) {
      // If a scalar is added to a tensor, we assume that the scalar is of
      // the same dtype as the tensor, as oneDNN graph currently requires
      // inputs of binary ops to have the same dtype.
      // We create a 1-D tensor from the scalar input & "promote" its dtype
      // to that of the first input. Doing so helps us satisfy PyTorch's
      // type promotion rules.
      // Although we convert the scalar to a tensor, we still need to promote
      // types as if the second input were still a scalar.
      // The following sample code-snippet illustrates that converting a
      // scalar input to a 1-D tensor may result in a different output dtype
      // than would otherwise have been the case.
      // clang-format off
      //  >>> (1. + torch.rand([2]).half()).dtype
      //  torch.float16
      //  >>> (torch.tensor(1.).unsqueeze(0) + (torch.rand([2]).half())).dtype
      //  torch.float32
      // clang-format on
      auto promotedDtype = dtypeOfFirstInput;
      auto scalar = node->input(1);
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      // 42 : Scalar  -->  tensor(42.0) : Float([])
      auto t = g->insert(aten::as_tensor, {scalar}, {{"dtype", promotedDtype}});
      // add dim & stride info to IR
      std::optional<size_t> t_dim = 1;
      auto target_type = TensorTypePtr(
          TensorType::create(promotedDtype, at::kCPU, t_dim, false));
      target_type = target_type->withSizes({1});
      t->setType(target_type);
      // tensor(42.0) : Float([])  -->  tensor([42.0]) : Float([1])
      auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
      unsqueezed->setType(target_type);
      node->replaceInput(1, unsqueezed);
      // dtype might have changed, so it needs to be updated in the IR as well
      node->output()->setType(
          node->output()->type()->expect<TensorType>()->withScalarType(
              promotedDtype));
    } else if (node->input(1)->type()->isSubtypeOf(TensorType::get())) {
      // Here, both inputs are tensors, and we just need to make sure that
      // they have the same dtype, as oneDNN graph requires both inputs to
      // have the same dtype. We follow PyTorch's type-promotion rules here.
      auto second_input_typeptr = node->input(1)->type()->expect<TensorType>();
      std::optional<at::ScalarType> second_input_type =
          second_input_typeptr->scalarType();
      if (second_input_type != std::nullopt) {
        // dtype of the second tensor might not be available in the IR
        auto dtypeOfSecondInput = second_input_type.value();
        if (dtypeOfFirstInput != dtypeOfSecondInput) {
          // Type promotion is required
          auto promotedDtype =
              c10::promoteTypes(dtypeOfFirstInput, dtypeOfSecondInput);
          WithInsertPoint guard(node);
          auto g = node->owningGraph();
          if (promotedDtype == dtypeOfFirstInput) {
            auto to_node_output = g->insert(
                aten::to, {node->input(1)}, {{"dtype", promotedDtype}});
            to_node_output->setType(
                node->input(1)->type()->expect<TensorType>()->withScalarType(
                    promotedDtype));
            node->replaceInput(1, to_node_output);
          } else {
            auto to_node_output = g->insert(
                aten::to, {node->input(0)}, {{"dtype", promotedDtype}});
            to_node_output->setType(
                node->input(0)->type()->expect<TensorType>()->withScalarType(
                    promotedDtype));
            node->replaceInput(0, to_node_output);
          }
          // dtype might have changed, so it needs to be updated in the IR as
          // well
          node->output()->setType(
              node->output()->type()->expect<TensorType>()->withScalarType(
                  promotedDtype));
        } else {
          // both dtypes are the same
          // dtype info is sometimes missing in the JIT IR, and we shouldn't
          // treat those tensors as FP32 tensors by default.
          node->output()->setType(
              node->output()->type()->expect<TensorType>()->withScalarType(
                  dtypeOfFirstInput));
        }
      } // end inner if block
    } // end outer if block
  }
}
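
// Recursively visits all nodes in the block (and its sub-blocks) and
// normalizes the inputs of aten::add, aten::mul, and aten::div via
// handleBinaryOpInputs.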
static void ConvertScalarToTensor(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      ConvertScalarToTensor(sub);
    }
    if (node->kind() == aten::add || node->kind() == aten::mul ||
        node->kind() == aten::div) {
      handleBinaryOpInputs(node);
    }
  }
}
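
// Rewrites aten::add(self, other, alpha) with a constant alpha != 1 as an
// explicit multiply followed by an add with alpha == 1, so that the add maps
// onto a plain elementwise add for oneDNN graph:
//   add(self, other, alpha)  -->  add(self, mul(other, alpha), 1)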
static void mayDecomposeAdd(Node* node) {
  if (node->inputs().size() < 3) {
    return; // corner-case in BERT-mrpc that's not in line with
            // native_functions.yaml
  }
  if (toIValue(node->namedInput("alpha")).has_value()) {
    auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0);
    if (!alphaEqualsOne) {
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      auto mul = g->insert(
          aten::mul, {node->namedInput("other"), node->namedInput("alpha")});
      if (node->namedInput("other")->type()->isSubtypeOf(TensorType::get())) {
        auto mulTensorTypePtr = node->namedInput("other")->type();
        mul->setType(mulTensorTypePtr);
      }
      node->replaceInput(1, mul);
      auto one = g->insertConstant(1.0);
      node->replaceInput(2, one);
    }
  }
}
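
// Applies mayDecomposeAdd to every aten::add node, recursing into sub-blocks.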
static void DecomposeFusedAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      DecomposeFusedAdd(sub);
    }
    if (node->kind() == aten::add) {
      mayDecomposeAdd(node);
    }
  }
}
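
// Replaces the outputs of identity ops (x + 0 and x * 1) with x itself; the
// now-dead add/mul nodes are removed later by EliminateDeadCode.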
static void EliminateIdentityMulAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      EliminateIdentityMulAdd(sub);
    }
    if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) ||
        (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) {
      node->output()->replaceAllUsesWith(node->namedInput("self"));
    }
  }
}
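
// Entry point of this preparation pass: decompose fused adds, drop identity
// add/mul nodes, clean up the resulting dead code, and finally convert any
// remaining scalar operands of binary ops to tensors.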
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
  DecomposeFusedAdd(graph->block());
  EliminateIdentityMulAdd(graph->block());
  EliminateDeadCode(graph);
  // ConvertScalarToTensor must be placed after EliminateIdentityMulAdd,
  // because once a scalar operand has been turned into a 1-D tensor,
  // compareConstValue can no longer recognize the identity constants 0 and 1.
  ConvertScalarToTensor(graph->block());
}
} // namespace torch::jit::fuser::onednn