File: prepare_binary.cpp

#include <aten/src/ATen/core/jit_type.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/shape_analysis.h>

namespace torch::jit::fuser::onednn {

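// Returns true if `v` is a constant int or double whose value equals `d`.
// Used below to detect identity operands (x + 0, x * 1) and alpha == 1.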
static bool compareConstValue(Value* v, double d) {
  auto ival = toIValue(v);
  return ival.has_value() &&
      ((ival->isInt() && static_cast<int>(ival->toInt()) == d) ||
       (ival->isDouble() && ival->toDouble() == d));
}

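// Makes both inputs of a binary op tensors with the same dtype, since oneDNN
// Graph requires matching input dtypes:
//  - a scalar second input is materialized as a 1-D tensor of the first
//    input's dtype (aten::as_tensor followed by aten::unsqueeze);
//  - two tensor inputs with different dtypes are reconciled via
//    c10::promoteTypes, with an aten::to cast inserted on the second input
//    if the first already has the promoted dtype, and on the first otherwise.
// The output dtype recorded in the IR is updated accordingly.
//
// Sketch of the scalar case in IR-like pseudocode (illustrative, not exact
// IR text):
//   %y = aten::add(%t : Half(2), %s : float, %alpha)
// becomes
//   %st = aten::as_tensor(%s, dtype=Half)   // scalar -> 0-D tensor
//   %su = aten::unsqueeze(%st, 0)           // 0-D -> 1-D tensor
//   %y  = aten::add(%t, %su, %alpha)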
static void handleBinaryOpInputs(Node* node) {
  // We do not handle binary ops with two scalar inputs,
  // and we assume the scalar is always the second input.
  if (node->input(0)->type()->isSubtypeOf(TensorType::get())) {
    auto dtypeOfFirstInput =
        node->input(0)->type()->cast<TensorType>()->scalarType().value();
    if (node->input(1)->type()->isSubtypeOf(FloatType::get()) ||
        node->input(1)->type()->isSubtypeOf(IntType::get())) {
      // If a scalar is added to a tensor, we assume that the
      // scalar has the same dtype as the tensor, since oneDNN Graph
      // currently requires inputs of binary ops to have the same dtype.
      // We create a 1D tensor from the scalar input & "promote" its
      // dtype to that of the first input. Doing so helps us satisfy PyTorch's
      // type promotion rules.
      // Although we convert the scalar to a tensor, we still need to promote
      // types, as if the second input were still a scalar.
      // The following sample code-snippet illustrates that converting a scalar
      // input to a 1-D tensor may result in a different output dtype than would
      // otherwise have been the case.
      // clang-format off
      //   >>> (1. + torch.rand([2]).half()).dtype
      //       torch.float16
      //   >>> (torch.tensor(1.).unsqueeze(0) + (torch.rand([2]).half())).dtype
      //       torch.float32
      // clang-format on
      auto promotedDtype = dtypeOfFirstInput;
      auto scalar = node->input(1);
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      // 42 : Scalar  -->  tensor(42.0) : Float([])
      auto t = g->insert(aten::as_tensor, {scalar}, {{"dtype", promotedDtype}});
      // add dim & stride info to IR
      std::optional<size_t> t_dim = 1;
      auto target_type = TensorTypePtr(
          TensorType::create(promotedDtype, at::kCPU, t_dim, false));
      target_type = target_type->withSizes({1});
      t->setType(target_type);

      // tensor(42.0) : Float([])  -->  tensor([42.0]) : Float([1])
      auto unsqueezed = g->insert(aten::unsqueeze, {t, 0});
      unsqueezed->setType(target_type);
      node->replaceInput(1, unsqueezed);

      // the output dtype might have changed, so it needs to be updated in
      // the IR as well
      node->output()->setType(
          node->output()->type()->expect<TensorType>()->withScalarType(
              promotedDtype));
    } else if (node->input(1)->type()->isSubtypeOf(TensorType::get())) {
      // Here, both inputs are tensors, and we just need to make sure that
      // they have the same dtype, as oneDNN Graph requires both inputs to
      // have the same dtype. We follow PyTorch's type-promotion rules here.
      auto second_input_typeptr = node->input(1)->type()->expect<TensorType>();
      std::optional<at::ScalarType> second_input_type =
          second_input_typeptr->scalarType();
      if (second_input_type != std::nullopt) {
        // dtype of the second tensor might not be available in the IR
        auto dtypeOfSecondInput = second_input_type.value();
        if (dtypeOfFirstInput != dtypeOfSecondInput) {
          // Type promotion is required
          auto promotedDtype =
              c10::promoteTypes(dtypeOfFirstInput, dtypeOfSecondInput);
          WithInsertPoint guard(node);
          auto g = node->owningGraph();
          if (promotedDtype == dtypeOfFirstInput) {
            auto to_node_output = g->insert(
                aten::to, {node->input(1)}, {{"dtype", promotedDtype}});
            to_node_output->setType(
                node->input(1)->type()->expect<TensorType>()->withScalarType(
                    promotedDtype));
            node->replaceInput(1, to_node_output);
          } else {
            auto to_node_output = g->insert(
                aten::to, {node->input(0)}, {{"dtype", promotedDtype}});
            to_node_output->setType(
                node->input(0)->type()->expect<TensorType>()->withScalarType(
                    promotedDtype));
            node->replaceInput(0, to_node_output);
          }
          // the output dtype might have changed, so it needs to be updated
          // in the IR as well
          node->output()->setType(
              node->output()->type()->expect<TensorType>()->withScalarType(
                  promotedDtype));
        } else {
          // Both dtypes are the same.
          // dtype info is sometimes missing from the JIT IR, and we should
          // not treat such tensors as FP32 by default, so record the dtype
          // explicitly on the output.
          node->output()->setType(
              node->output()->type()->expect<TensorType>()->withScalarType(
                  dtypeOfFirstInput));
        }
      } // end inner if block
    } // end outer if block
  }
}

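// Recursively visits all nodes in the block (and its sub-blocks) and
// normalizes the inputs of aten::add, aten::mul, and aten::div nodes via
// handleBinaryOpInputs above.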
static void ConvertScalarToTensor(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      ConvertScalarToTensor(sub);
    }

    if (node->kind() == aten::add || node->kind() == aten::mul ||
        node->kind() == aten::div) {
      handleBinaryOpInputs(node);
    }
  }
}

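// Decomposes a fused add with alpha != 1, i.e. add(self, other, alpha),
// into add(self, mul(other, alpha), 1), so the scaling becomes an explicit
// aten::mul node. Nodes with alpha == 1 or without an alpha input are left
// untouched.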
static void mayDecomposeAdd(Node* node) {
  if (node->inputs().size() < 3) {
    return; // corner case seen in BERT-mrpc: an aten::add node without the
            // alpha input, which is not in line with native_functions.yaml
  }
  if (toIValue(node->namedInput("alpha")).has_value()) {
    auto alphaEqualsOne = compareConstValue(node->namedInput("alpha"), 1.0);
    if (!alphaEqualsOne) {
      WithInsertPoint guard(node);
      auto g = node->owningGraph();
      auto mul = g->insert(
          aten::mul, {node->namedInput("other"), node->namedInput("alpha")});
      if (node->namedInput("other")->type()->isSubtypeOf(TensorType::get())) {
        auto mulTensorTypePtr = node->namedInput("other")->type();
        mul->setType(mulTensorTypePtr);
      }
      node->replaceInput(1, mul);
      auto one = g->insertConstant(1.0);
      node->replaceInput(2, one);
    }
  }
}

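// Recursively applies mayDecomposeAdd to every aten::add node in the block
// and its sub-blocks.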
static void DecomposeFusedAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      DecomposeFusedAdd(sub);
    }

    if (node->kind() == aten::add) {
      mayDecomposeAdd(node);
    }
  }
}

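// Replaces identity binary ops (add of a constant 0, mul by a constant 1)
// with their "self" input; the now-unused nodes are removed by the
// EliminateDeadCode call in PrepareBinaryForLLGA.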
static void EliminateIdentityMulAdd(Block* block) {
  for (auto node : block->nodes()) {
    for (auto sub : node->blocks()) {
      EliminateIdentityMulAdd(sub);
    }

    if ((node->kind() == aten::add && compareConstValue(node->input(1), 0.0)) ||
        (node->kind() == aten::mul && compareConstValue(node->input(1), 1.0))) {
      node->output()->replaceAllUsesWith(node->namedInput("self"));
    }
  }
}

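// Entry point of the pass: decompose fused adds into explicit mul + add,
// drop identity add/mul nodes, eliminate the resulting dead code, and then
// convert remaining scalar operands of binary ops to tensors so that oneDNN
// Graph (LLGA) sees tensor-only binary ops with consistent dtypes.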
void PrepareBinaryForLLGA(const std::shared_ptr<Graph>& graph) {
  DecomposeFusedAdd(graph->block());
  EliminateIdentityMulAdd(graph->block());
  EliminateDeadCode(graph);
  // ConvertScalarToTensor must run after EliminateIdentityMulAdd, since the
  // latter matches scalar constant operands, which would no longer be
  // scalars once converted to 1-D tensors.
  ConvertScalarToTensor(graph->block());
}

} // namespace torch::jit::fuser::onednn