File: torch/csrc/jit/codegen/onednn/operator.h

#pragma once

#include <climits> // for LONG_MAX, used for End-op ids below

#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/LlgaTensorImpl.h>
#include <torch/csrc/jit/ir/ir.h>

namespace torch::jit::fuser::onednn {

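// Builder that maps a TorchScript JIT Node onto a oneDNN Graph (LLGA) op.
// Inputs, outputs, and attributes are added fluently; the finished
// dnnl::graph::op is retrieved via llgaOp().
//
// A minimal usage sketch (the ReLU mapping here is illustrative, not taken
// from this file):
//
//   auto op = Operator(node, dnnl::graph::op::kind::ReLU)
//                 .setInput(0)
//                 .setOutput(g, 0);
//   g->add_op(op.llgaOp());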
class Operator {
 public:
  Operator(const Node* node, dnnl::graph::op::kind kind)
      : n(node), o(getId(node), kind, node->kind().toQualString()), k(kind) {}

  // Returns the output index if the Value is a graph output;
  // otherwise returns -1.
  int32_t graphOutputIdx(Value* v) {
    int32_t i = 0;
    for (const Value* output : v->owningGraph()->outputs()) {
      if (v == output) {
        return i;
      }
      i++;
    }
    return -1;
  }

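  // Adds v as an op input. Values that may be None, or whose type is not
  // Tensor, are skipped.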
  Operator& setInputValue(Value* v) {
    if (v->mustNotBeNone() && v->type()->kind() == c10::TensorType::Kind) {
      o.add_input(createLogicalTensor(v));
    }
    return *this;
  }

  Operator& setInput(size_t offset) {
    return setInputValue(n->input(offset));
  }

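  // Variadic overload: adds several inputs by offset in one call.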
  template <typename... Ts>
  Operator& setInput(size_t offset, Ts... other) {
    setInput(offset);
    return setInput(other...);
  }

  Operator& setOutputValue(Value* v) {
    if (v->mustNotBeNone()) {
      o.add_output(createLogicalTensor(v));
    }
    return *this;
  }

  // setOutputValue & setOutput require a pointer to the LLGA graph, as output
  // logical tensors that are graph outputs should be connected to an End LLGA
  // op. An empty graph pointer may be passed to retain the legacy behavior of
  // this function.
  Operator& setOutputValue(Value* v, std::unique_ptr<dnnl::graph::graph>& g) {
    if (v->mustNotBeNone()) {
      auto output_tensor = createLogicalTensor(v);
      o.add_output(output_tensor);
      if (g) {
        int32_t outputIndex = graphOutputIdx(v);
        if (outputIndex != -1) {
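          // Ids for End ops count down from LONG_MAX, which keeps them clear
          // of the pointer-derived ids getId() assigns to regular ops.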
          dnnl::graph::op newEndNode(
              LONG_MAX - outputIndex,
              dnnl::graph::op::kind::End,
              "EndNodeForGraphOutput");
          newEndNode.add_input(output_tensor);
          g->add_op(newEndNode);
        }
      }
    }
    return *this;
  }

  Operator& setOutput(std::unique_ptr<dnnl::graph::graph>& g, size_t offset) {
    return setOutputValue(n->output(offset), g);
  }

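  // Legacy overload: adds the output without connecting graph outputs to an
  // End op.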
  Operator& setOutput(size_t offset) {
    return setOutputValue(n->output(offset));
  }

  template <typename... Ts>
  Operator& setOutput(
      std::unique_ptr<dnnl::graph::graph>& g,
      size_t offset,
      Ts... other) {
    setOutput(g, offset);
    return setOutput(g, other...);
  }

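  // Sets an LLGA op attribute, either from a value directly or by applying an
  // extractor fn(node, offset) to one of the node's constant inputs.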
  template <typename Attr>
  Operator& setAttr(dnnl::graph::op::attr name, Attr&& attr) {
    o.set_attr(name, std::forward<Attr>(attr));
    return *this;
  }

  template <typename F>
  Operator& setAttr(dnnl::graph::op::attr name, const F& fn, size_t offset) {
    return setAttr(name, fn(n, offset));
  }

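  // Extractors for a node's constant inputs, matching the fn parameter of
  // setAttr(name, fn, offset) above.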
  static float ScalarToFloat(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toScalar().to<float>();
  }

  static std::vector<int64_t> Ints(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toIntVector();
  }

  static int64_t Int(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toInt();
  }

  static float Float(const Node* node, size_t offset) {
    return static_cast<float>(toIValue(node->input(offset))->toDouble());
  }

  static bool Bool(const Node* node, size_t offset) {
    return toIValue(node->input(offset))->toBool();
  }

  static uint64_t getId(const Node* node) {
    return reinterpret_cast<uint64_t>(node); // cast node address as op id
  }

  dnnl::graph::op::kind kind() const {
    return k;
  }

  dnnl::graph::op llgaOp() const {
    return o;
  }

 private:
  dnnl::graph::logical_tensor createLogicalTensor(Value* value) const {
    return LlgaTensorDesc(value).logical_tensor();
  }

  const Node* n; // the JIT node this op was built from
  dnnl::graph::op o; // the LLGA op under construction
  dnnl::graph::op::kind k; // cached LLGA op kind
};

} // namespace torch::jit::fuser::onednn