File: register_c10_ops.cpp

package info (click to toggle)
pytorch 2.6.0+dfsg-9
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 161,948 kB
  • sloc: python: 1,278,832; cpp: 900,333; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (63 lines) | stat: -rw-r--r-- 2,034 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <ATen/core/ATenOpList.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/record_function.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch::jit {

namespace {

// Wraps a c10 dispatcher operator handle as a JIT Operator.
//
// The handle is copied into the closure, so the returned Operator does not
// refer back to the caller's `op`. Invoking the operation forwards the JIT
// interpreter stack unchanged to the handle's boxed call path.
Operator createOperatorFromC10(const c10::OperatorHandle& op) {
  return Operator(
      op, [handle = op](Stack& stack) { handle.callBoxed(stack); });
}

class RegistrationListener final : public c10::OpRegistrationListener {
 public:
  void onOperatorRegistered(const c10::OperatorHandle& op) override {
    if (op.schema().name() == "aten::backward") {
      // aten::backward has a manual wrapper in register_prim_ops_fulljit.cpp.
      // We should not additionally export the c10 aten::backward op from
      // native_functions.yaml to JIT. This special handling is needed because
      // aten::backward requires AliasAnalysisKind::CONSERVATIVE but all ops
      // from native_functions.yaml get AliasAnalysisKind::FROM_SCHEMA.
      // TODO Find a better way to handle this.
      return;
    }
    torch::jit::registerOperator(createOperatorFromC10(op));
  }

  void onOperatorDeregistered(const c10::OperatorHandle& op) override {
    if (op.schema().name() == "aten::backward") {
      // see comment in onOperatorRegistered for why aten::backward is excluded
      return;
    }
    torch::jit::deregisterOperator(op.schema());
  }
};

// Holds the subscription of a RegistrationListener on the global c10
// dispatcher.
//
// Adding the listener immediately invokes it for every operator that already
// exists, and it is invoked again whenever a new operator is registered later.
// The subscription lives in an RAII handle member. NOTE(review): presumably
// destroying that handle detaches the listener; confirm against
// c10::RegistrationHandleRAII.
struct Registerer final {
  c10::RegistrationHandleRAII listenerRAII =
      c10::Dispatcher::singleton().addRegistrationListener(
          std::make_unique<RegistrationListener>());
};

// Returns the single Registerer instance, constructing it on the first call.
// C++11 and later guarantee thread-safe initialization of function-local
// statics, so concurrent first calls still construct exactly one instance.
Registerer& registerer() {
  static Registerer instance;
  return instance;
}

// global instance to run its constructor on startup
[[maybe_unused]] Registerer& dummy = registerer();

} // namespace

// Ensures the c10 -> JIT registration listener is installed by constructing the
// function-local Registerer singleton if that has not happened yet. Once the
// singleton exists, further calls have no additional effect.
// NOTE(review): presumably exported so other code can force this translation
// unit to be linked in and initialized; confirm against its call sites.
void ensure_c10_registerer_defined() {
  static_cast<void>(registerer());
}

} // namespace torch::jit