File: interface.cpp

#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
#include <torch/csrc/jit/codegen/onednn/interface.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/operator_options.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void fuseGraph(std::shared_ptr<Graph>& g) {
  // Follow the process of the tensorexpr_fuser in profiling mode:
  // Remove prim::profile nodes and embed the profile info directly in the
  // IR in value types to avoid breaking the fusion patterns.
  // A shape guard will be added after the LLGA optimization passes, and the
  // tensor type information will then be wiped from the IR so that it is not
  // accidentally used by any other pass.

  // We rely on shape specialization and the shape guard to ensure the
  // validity of the cached compilation in the kernel, so only profiling mode
  // is supported.
  // TODO: add check on oneDNNFusionGroup to ensure allShapesAreKnown on nodes
  // to fuse: torch/csrc/jit/passes/tensorexpr_fuser.cpp: allShapesAreKnown
  if (getProfilingMode()) {
    GRAPH_DUMP(
        "Before RemoveProfileNodesAndSpecializeTypes. Beginning of LLGA "
        "optimization pass",
        g);
    RemoveProfileNodesAndSpecializeTypes(g);
    GRAPH_DUMP(
        "After RemoveProfileNodesAndSpecializeTypes. Before mutation removal",
        g);

    RemoveTensorMutation(g, [](Node* nodeToFunctionalize) {
      static std::unordered_set<Symbol> supportedOps = {
          aten::add_,
          aten::mul_,
          aten::tanh_,
          aten::elu_,
          aten::relu_,
          aten::relu6_,
          aten::gelu_,
          aten::sqrt_,
          aten::sigmoid_,
          aten::hardtanh_,
          aten::abs_,
          aten::square_,
      };
      return supportedOps.count(nodeToFunctionalize->kind()) != 0;
    });
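    // For example (illustrative IR, not a dump from this file): the pass
    // rewrites an in-place op from the list above into its functional form,
    //   %y = aten::relu_(%x)   ==>   %y = aten::relu(%x)
    // so the fuser only has to reason about pure ops.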
    RemoveListMutation(g);
    GRAPH_DUMP("After mutation removal. Before PrepareBinaryForLLGA", g);
    PrepareBinaryForLLGA(g);
    GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g);
    DeferSizeCheck(g);
    GRAPH_DUMP("After DeferSizeCheck. Before CreateLlgaSubgraphs", g);
    CreateLlgaSubgraphs(g);
    GRAPH_DUMP("After CreateLlgaSubgraphs. Before PropagateLayout", g);
    PropagateLayout(g);
    GRAPH_DUMP(
        "After PropagateLayout. Before prepareFusionGroupAndGuardOutputs", g);

    // Add shape guard for profiling mode and wipe the tensor type information
    // from the IR
    prepareFusionGroupAndGuardOutputs(g->block());
    GRAPH_DUMP(
        "After prepareFusionGroupAndGuardOutputs. Before "
        "RemoveTensorTypeSpecializations",
        g);
    RemoveTensorTypeSpecializations(g);
    GRAPH_DUMP(
        "After RemoveTensorTypeSpecializations. End of LLGA optimization pass",
        g);
  }
}

} // namespace onednn
} // namespace fuser
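// How fuseGraph gets wired in (a hedged sketch, not part of this file): the
// pass_manager.h header included above provides RegisterPass, so a build that
// wanted LLGA fusion unconditionally could register the pass along these
// lines. In PyTorch proper the pass is instead gated behind a runtime switch
// (torch.jit.enable_onednn_fusion on the Python side).
//
//   static RegisterPass llgaPass([](std::shared_ptr<Graph>& g) {
//     fuser::onednn::fuseGraph(g);
//   });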

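// Wraps the LLGA fusion node's subgraph in an Operation. The LlgaKernel
// instance, together with the compilation it caches, is owned by the
// shared_ptr captured in the returned lambda, so it is reused across
// invocations of the same fusion group.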
Operation createLlgaKernel(const Node* node) {
  auto kernel = std::make_shared<fuser::onednn::LlgaKernel>(node);
  return [kernel](Stack* stack) {
    RECORD_FUNCTION(kernel->debugName(), std::vector<c10::IValue>());
    kernel->run(*stack);
    return 0;
  };
}

RegisterOperators oneDNNFusionGroupOp({
    torch::jit::Operator(
        prim::oneDNNFusionGroup,
        createLlgaKernel,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});
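// For orientation (an illustrative sketch, not a dump from this build):
// after CreateLlgaSubgraphs the IR contains fusion nodes of roughly this
// shape, each carrying the fused subgraph as an attribute that
// createLlgaKernel compiles through oneDNN Graph:
//
//   %y = prim::oneDNNFusionGroup[Subgraph=<Graph>](%x, %w)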

// Currently, we convert some scalar inputs, such as the second argument of
// binary ops, to a 1D tensor. Other scalar inputs are prim::Constant nodes.
// If we ever need to guard scalar inputs in the future, some of the logic
// here would have to change.
Operation createLlgaGuardKernel(const Node* node) {
  return [node](Stack* stack) {
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("Guarding node: ", node->kind().toQualString());
#endif
    std::vector<TypePtr> types = node->tys(attr::types);
    const auto num_inputs = types.size();
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("num_inputs to guard: ", num_inputs);
#endif
    for (size_t i = 0; i < num_inputs; i++) {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("checking input ", i);
#endif
      auto& input = peek(stack, i, num_inputs);
      const c10::TensorTypePtr& guard_tensor_type =
          types[i]->cast<TensorType>();

      if (!input.isTensor()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is not a tensor, return false");
#endif
        push(stack, IValue(false));
        return;
      }
      const at::Tensor& tensor = input.toTensor();

      // If the input tensor is an mkldnn tensor, it originated from an
      // upstream LLGA partition that has already passed the check on input
      // shapes. It is valid to continue here as long as the output shapes of
      // oneDNN graph partitions are determined by their input shapes.
      if (tensor.is_mkldnn()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is_mkldnn, continue");
#endif
        continue;
      }

      if (!guard_tensor_type->matchTensor(tensor)) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " check failed, return false");
#endif
        push(stack, IValue(false));
        return;
      }
    }
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("all check done, return true");
#endif
    push(stack, IValue(true));
    return;
  };
}

RegisterOperators oneDNNGuardOp({
    torch::jit::Operator(
        prim::oneDNNFusionGuard,
        createLlgaGuardKernel,
        AliasAnalysisKind::FROM_SCHEMA),
});
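// Sketch of how the guard participates in the IR (illustrative; these nodes
// are inserted by prepareFusionGroupAndGuardOutputs, not written by hand).
// The guard compares the runtime inputs against the profiled types and feeds
// a prim::If that dispatches to the fused kernel or an unfused fallback:
//
//   %ok = prim::oneDNNFusionGuard[types=[...]](%x, %w)
//   %y = prim::If(%ok)
//     block0():  // fused path: prim::oneDNNFusionGroup
//     block1():  // fallback: original unfused subgraph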
} // namespace jit
} // namespace torch