File: python_dispatch.cpp

package info (click to toggle)
pytorch-cuda 2.6.0%2Bdfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (1028 lines) | stat: -rw-r--r-- 37,576 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <ATen/functorch/BatchedTensorImpl.h>
#include <torch/library.h>

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>

#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>
#include <utility>

namespace py = pybind11;

namespace torch::impl::dispatch {

// Registry of Python kernels installed through the `_DispatchModule.impl`
// binding, keyed first by operator name and then by dispatch key. Each value
// is the Python callable wrapped in a SafePyObject tagged with the
// PyInterpreter that registered it. Entries are written with
// insert_or_assign, so re-registering the same (operator, key) pair replaces
// the previous callable.
// NOTE(review): no lock is taken around this map in the visible code;
// presumably all accesses happen under the GIL — confirm.
//
// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it.
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;

// Translates the textual library kind passed from Python (e.g. by
// `_dispatch_library`) into torch::Library::Kind.
// Accepted spellings: "DEF", "IMPL", "FRAGMENT". Anything else raises a
// c10::Error through TORCH_CHECK.
static torch::Library::Kind parseKind(const std::string& k) {
  static const std::unordered_map<std::string, torch::Library::Kind>
      kLibraryKinds{
          {"DEF", torch::Library::DEF},
          {"IMPL", torch::Library::IMPL},
          {"FRAGMENT", torch::Library::FRAGMENT},
      };
  const auto match = kLibraryKinds.find(k);
  TORCH_CHECK(match != kLibraryKinds.end(), "could not parse ", k);
  return match->second;
}
// Translates the textual alias-analysis kind passed from Python into
// c10::AliasAnalysisKind.
// Accepted spellings: "CONSERVATIVE", "FROM_SCHEMA", "PURE_FUNCTION", and ""
// (the Python-side default for the `alias` / `alias_analysis` arguments),
// which maps to FROM_SCHEMA. Anything else raises a c10::Error through
// TORCH_CHECK.
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static const std::unordered_map<std::string, c10::AliasAnalysisKind>
      kAliasKinds{
          {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
          {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
          {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
          {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
      };
  const auto match = kAliasKinds.find(k);
  TORCH_CHECK(match != kAliasKinds.end(), "could not parse ", k);
  return match->second;
}

// Wraps `raw_f` in a torch::CppFunction, optionally pinned to a dispatch key.
//
// `key` is the textual name of a dispatch key and is parsed with
// c10::parseDispatchKey. If `key` is empty, the returned CppFunction carries
// no dispatch key. A null `key` is treated the same way: pybind11 hands C++ a
// null `const char*` when Python passes `None` for such an argument, and
// indexing it would be undefined behavior.
template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  if (key == nullptr || key[0] == '\0') {
    return torch::CppFunction(std::forward<Func>(raw_f));
  }
  return torch::dispatch(
      c10::parseDispatchKey(key), std::forward<Func>(raw_f));
}

// RAII guard that, for its lifetime:
//   * turns on HermeticPyObjectTLS,
//   * adds DispatchKey::Python to the TLS exclude set,
//   * removes DispatchKey::Python and DispatchKey::PythonTLSSnapshot from the
//     TLS include set.
// The prior value of each of those four bits is captured when the guard is
// built and written back when it is destroyed.
struct EnableHermeticPyObject {
  EnableHermeticPyObject() {
    // By the time this body runs, the old_* members below have already
    // snapshotted the current TLS state via their default member
    // initializers, so it is safe to overwrite it now.
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }

  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }

  // Scoped-only: neither copyable nor movable.
  EnableHermeticPyObject(const EnableHermeticPyObject&) = delete;
  EnableHermeticPyObject(EnableHermeticPyObject&&) = delete;
  EnableHermeticPyObject& operator=(const EnableHermeticPyObject&) = delete;
  EnableHermeticPyObject& operator=(EnableHermeticPyObject&&) = delete;

  // TLS state observed at construction, restored by the destructor.
  bool old_ = c10::impl::HermeticPyObjectTLS::get_state();
  bool old_excluded_python_ =
      c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python);
  bool old_python_ =
      c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python);
  bool old_python_snapshot_ = c10::impl::tls_is_dispatch_key_included(
      at::DispatchKey::PythonTLSSnapshot);
};

class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;
  // If "with_keyset", then we expect a keyset as the first arg.
  bool with_keyset_;
  // If "with_op", then we expect the op as first arg (or second if keyset)
  bool with_op_;

 public:
  PythonKernelHolder(
      py::object func,
      c10::DispatchKey dispatch_key,
      bool with_keyset = false,
      bool with_op = false)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key),
        with_keyset_(with_keyset),
        with_op_(with_op) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(
              op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor IF if has Python key (which
    // means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(
                  op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(
                    op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter

    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    // Jan 2024: We're slated to get rid of multipy, so stop forcing hermetic
    // mode unconditionally in all situations when you're using multipy.
    // Eventually just delete this entirely.  (Note that you may break multipy
    // anyway this way with dispatcher registered functions that require
    // hermetic to be off.)
#if defined(USE_DEPLOY)
    EnableHermeticPyObject g2;
#endif
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto func =
        py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
    auto obj = with_op_ ? with_keyset_
            ? func(
                  keyset,
                  torch::detail::getTorchApiFunction(op),
                  *args_kwargs.first,
                  **args_kwargs.second)
            : func(
                  torch::detail::getTorchApiFunction(op),
                  *args_kwargs.first,
                  **args_kwargs.second)
        : with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
                        : func(*args_kwargs.first, **args_kwargs.second);
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};

// Mode passed to torch::Library registrations made from Python.
// Returns REGISTER when running on the main Python interpreter and VERIFY on
// any other interpreter.
static torch::_RegisterOrVerify register_or_verify() {
  return isMainPyInterpreter() ? torch::_RegisterOrVerify::REGISTER
                               : torch::_RegisterOrVerify::VERIFY;
}

// Calls `handle` through the boxed dispatcher with Python arguments.
// `args`/`kwargs` are bound against the operator's schema to build an IValue
// stack, and the operator runs with the GIL released. The resulting stack is
// then converted back into a Python object.
static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    const py::args& args,
    const py::kwargs& kwargs) {
  const auto& op_schema = handle.schema();
  torch::jit::Stack ivalues = torch::jit::createStackForSchema(
      op_schema, args, kwargs, /*self=*/std::nullopt);

  {
    // Drop the GIL for the duration of the dispatcher call so other Python
    // threads may run; it is re-acquired when this scope closes.
    py::gil_scoped_release release_gil;
    handle.callBoxed(ivalues);
  }

  return torch::jit::createPyObjectForStack(std::move(ivalues));
}

// A small RAII guard that lets you explicitly *remove* a key from the TLS
// exclude set.
class SetExcludeDispatchKeyGuard {
 public:
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
  }
  ~SetExcludeDispatchKeyGuard() {
    c10::impl::tls_set_dispatch_key_excluded(k, old);
  }
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
  SetExcludeDispatchKeyGuard operator=(const SetExcludeDispatchKeyGuard&) =
      delete;
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
  SetExcludeDispatchKeyGuard operator=(SetExcludeDispatchKeyGuard&&) = delete;

 private:
  at::DispatchKey k;
  bool old;
};

void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema)
      .def("debug", &c10::OperatorHandle::debug)
      .def(
          "redispatch_boxed",
          [](const py::object& self,
             c10::DispatchKeySet keyset,
             py::args args,
             const py::kwargs& kwargs) {
            auto& handle = self.cast<c10::OperatorHandle&>();
            auto stack = torch::jit::createStackForSchema(
                handle.schema(),
                std::move(args),
                kwargs,
                /*self=*/std::nullopt);
            {
              pybind11::gil_scoped_release no_gil_guard;
              handle.redispatchBoxed(keyset, &stack);
            }
            return torch::jit::createPyObjectForStack(std::move(stack));
          });

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      .def(
          "reset",
          [](const py::object& self) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().reset();
            return;
          },
          "")
      // Some of these APIs are only for testing and do not work in multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from RegisterOperators() API
      // but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher.  So instead we provide a bunch of precanned
      // functions for testing purposes.  You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something
      //
      // Mangling scheme: args_rets.  One character per.
      //  t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl_with_aoti_compile",
          [](const py::object& self,
             const char* ns,
             const char* op_name_with_overload,
             c10::DispatchKey dispatch) {
            HANDLE_TH_ERRORS
            std::string reg_op_name =
                std::string(ns).append("::").append(op_name_with_overload);

            auto& lib = self.cast<torch::Library&>();
            lib.impl(
                reg_op_name.c_str(),
                torch::dispatch(
                    dispatch,
                    CppFunction::makeFromBoxedFunctor(
                        std::make_unique<
                            torch::inductor::AOTIPythonKernelHolder>(
                            dispatch, ns, op_name_with_overload))),
                register_or_verify());
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("ns"),
          py::arg("op_name_with_overload"),
          py::arg("dispatch"))
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch, with_keyset))),
                  register_or_verify());
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false)
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis,
             const std::vector<at::Tag>& tags) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), tags, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "",
          py::arg("tags") = std::vector<at::Tag>())
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "")
      .def(
          "fallback",
          [](const py::object& self,
             c10::DispatchKey dispatch,
             const py::object& func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.fallback(
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()));
            } else {
              lib.fallback(torch::dispatch(
                  dispatch,
                  CppFunction::makeFromBoxedFunctor(
                      std::make_unique<PythonKernelHolder>(
                          func, dispatch, with_keyset, /*with_op*/ true))));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false);

  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? std::nullopt
                : std::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
    } else {
      return op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });

  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      // Returns whether or not the kernel for this dispatach key is a
      // fallthrough kernel
      "_dispatch_kernel_for_dispatch_key_is_fallthrough",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        return op->isKernelFallthroughKernel(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table, for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, Then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will return
      // true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();

    std::vector<std::string> states;
    states.reserve(danglingImpls.size());
    for (auto& danglingImpl : danglingImpls) {
      states.emplace_back(danglingImpl.dumpState());
    }

    return states;
  });

  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();

    std::vector<std::string> names;
    names.reserve(op_names.size());
    for (auto& op : op_names) {
      std::stringstream ss;
      ss << op.name;
      if (!op.overload_name.empty()) {
        ss << "." << op.overload_name;
      }
      names.emplace_back(std::move(ss).str());
    }

    return names;
  });

  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
    return c10::toString(k);
  });
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
    return c10::toFunctionalityKey(k);
  });
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
  //  AutogradCPU
  //  AutogradCUDA
  //  ...
  //  AutogradPrivateUse3
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
    std::vector<c10::DispatchKey> keys;
    if (c10::isPerBackendFunctionalityKey(key)) {
      auto ks = c10::DispatchKeySet(key) |
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
      for (auto k : ks) {
        keys.push_back(k);
      }
    } else {
      keys.push_back(key);
    }
    return keys;
  });
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      // NestedTensor is not a backend key
      DEF_ONE(AutogradNestedTensor)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(Conjugate)
      DEF_ONE(ZeroTensor)
      DEF_ONE(Negative)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastMPS)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastIPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
  // clang-format on

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

      // clang-format off
  C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
  // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
          ;

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("raw_repr", &c10::DispatchKeySet::raw_repr)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);

  m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
      py::cast(at::functorch::kKeysToPropagateToWrapper);

  m.attr("_after_autograd_keyset") = py::cast(c10::after_autograd_keyset);
  m.attr("_after_ADInplaceOrView_keyset") =
      py::cast(c10::after_ADInplaceOrView_keyset);

  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
  });

  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
  });

  m.def("_dispatch_keyset_full", []() {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
  });

  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
    return c10::getBackendKeySetFromAutograd(k);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });
  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });
  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });
  m.def("_functionalization_reapply_views_tls", []() {
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  });
  m.def(
      "_dispatch_is_included_in_alias",
      [](c10::DispatchKey a, c10::DispatchKey b) {
        return c10::isIncludedInAlias(a, b);
      });

  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<
      c10::impl::ForceDispatchKeyGuard,
      c10::DispatchKeySet,
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
  py_context_manager<c10::impl::ForceDispatchKeyGuard>(
      m, "_PreserveDispatchKeyGuard");
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
      m, "_IncludeDispatchKeyGuard");
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");
  py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>(
      m, "_SetExcludeDispatchKeyGuard");

  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");
  py_context_manager<at::AutoDispatchBelowADInplaceOrView>(
      m, "_AutoDispatchBelowADInplaceOrView");

  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll print
  // out the name of every operator that the Dispatcher knows of. This can be
  // useful to answer questions like "list all operators that do not have a CPU
  // kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << '\n';
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_parse_dispatch_key",
      [](const char* dispatch_key) -> std::optional<c10::DispatchKey> {
        try {
          return c10::parseDispatchKey(dispatch_key);
        } catch (const c10::Error& err) {
          return std::nullopt;
        }
      });

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));
  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
    return c10::Dispatcher::singleton().getPyStub(
        c10::OperatorName(name, overload));
  });

  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
    return at::functionalization::impl::replace_(a, b);
  });
  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
    at::functionalization::impl::propagate_xla_data(a, b);
  });
  m.def("_commit_update", [](const at::Tensor& a) {
    return at::functionalization::impl::commit_update(a);
  });
  m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
    return at::functionalization::impl::unsafe_reset_storage(a);
  });

  m.def("_dispatch_key_for_device", [](const std::string& device_type) {
    auto device = c10::Device(device_type);
    TORCH_CHECK(
        !device.has_index(),
        "Expected device_type string to not have a device index; got ",
        device_type);
    return c10::toString(
        c10::computeDispatchKey(std::nullopt, std::nullopt, device));
  });

  m.def("_are_functorch_transforms_active", []() {
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
    return (
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
  });

  m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
  });

  m.def("_get_constant_bool_symnode", [](int64_t data) {
    return c10::SymNode(
        c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
  });

  m.def("_non_sym_sizes", [](const at::Tensor& a) {
    return a.sizes(); // NB: NOT sym_size
  });

  m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    // Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr
    // will throw.
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_throw_on_mutable_data_ptr();
  });

  // Invariant: you must ONLY call this with FakeTensors.
  m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_warn_deprecated_on_mutable_data_ptr();
  });

  m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
  m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
      .value("PROXY", TorchDispatchModeKey::PROXY)
      .value("FAKE", TorchDispatchModeKey::FAKE);
}

// TODO: dedupe with the kernel
//
// Boxed trampoline that forwards a dispatcher call for `op` at dispatch key
// `key` to the Python callable stored in python_registrations_[op][key].
//
// It pops the operator's schema arguments off `stack` and acquires the GIL.
// It then converts the arguments to Python *args / **kwargs
// (parseIValuesToPyArgsKwargs) and calls the registered callable. The flags
// select which extra leading positional arguments are passed:
//   with_op && with_keyset : callable(keyset, op_fn, *args, **kwargs)
//   with_op only           : callable(op_fn, *args, **kwargs)
//   with_keyset only       : callable(keyset, *args, **kwargs)
//   neither                : callable(*args, **kwargs)
// where op_fn is torch::detail::getTorchApiFunction(op).
//
// The Python result is pushed back onto `stack` via pushPyOutToStack.
//
// Error behavior:
//   - Internal-asserts if no Python kernel is registered for (op, key).
//   - Throws python_error if the call produces a null object.
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack,
    bool with_keyset,
    bool with_op) {
  // Pop the IValues before taking the GIL; this step does not touch Python.
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);

  // Use find() instead of operator[] so a missing registration does not
  // default-insert empty entries into the global registry before the
  // assertion fires.
  const auto op_it = python_registrations_.find(op.operator_name());
  TORCH_INTERNAL_ASSERT(
      op_it != python_registrations_.end(),
      "No Python kernels registered for operator ",
      op.operator_name());
  const auto key_it = op_it->second.find(key);
  TORCH_INTERNAL_ASSERT(
      key_it != op_it->second.end() && key_it->second != nullptr,
      "No Python kernel registered for operator ",
      op.operator_name(),
      " at dispatch key ",
      key);

  auto* pyobj = key_it->second->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  // Take our own reference so the callable outlives any change to the
  // registry entry while Python code runs.
  auto callable = py::reinterpret_borrow<py::object>(pyobj);

  py::object obj;
  if (with_op && with_keyset) {
    obj = callable(
        keyset,
        torch::detail::getTorchApiFunction(op),
        *args_kwargs.first,
        **args_kwargs.second);
  } else if (with_op) {
    obj = callable(
        torch::detail::getTorchApiFunction(op),
        *args_kwargs.first,
        **args_kwargs.second);
  } else if (with_keyset) {
    obj = callable(keyset, *args_kwargs.first, **args_kwargs.second);
  } else {
    obj = callable(*args_kwargs.first, **args_kwargs.second);
  }

  // Defensive check: surface a null result as a Python error.
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace torch::impl::dispatch