File: Graph.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (52 lines) | stat: -rw-r--r-- 1,773 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include <torch/csrc/python_headers.h>

#include <pybind11/chrono.h>

#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/cuda/CUDAGraph.h>

// Cargo-culted partly from csrc/distributed/c10d/init.cpp and partly from
// csrc/cuda/Stream.cpp. Like THCPStream_init, THCPGraph_init is declared at
// global scope (outside any namespace).

// THCPGraph_init is forward-declared in its only consumer (csrc/Module.cpp),
// so a separate Graph.h header does not appear to be needed.

/// Shorthand for a pybind11 class binding whose Python-side objects own the
/// wrapped C++ instance through a std::shared_ptr holder.
template <class T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

/// Registers the CUDA-graph bindings on the torch._C extension module.
///
/// It defines:
///   * `_graph_pool_handle`, a free function forwarding to
///     ::at::cuda::graph_pool_handle.
///   * `_CUDAGraph`, a default-constructible binding of ::at::cuda::CUDAGraph
///     held by std::shared_ptr. It exposes capture_begin (optional `pool`
///     keyword, default {0, 0}), capture_end, replay, reset and pool.
///
/// Every `_CUDAGraph` method releases the GIL while the underlying C++ call
/// runs.
///
/// @param module  Borrowed pointer to the torch._C module object.
void THCPGraph_init(PyObject* module) {
  using Graph = ::at::cuda::CUDAGraph;
  // Policy shared by every method below: drop the GIL for the C++ call.
  using NoGil = py::call_guard<py::gil_scoped_release>;

  // pybind11's changelog calls `py::module_` the newer spelling, but the CI
  // linter and some builds here prefer the older `py::module` name.
  py::module torch_C = py::handle(module).cast<py::module>();

  torch_C.def("_graph_pool_handle", &::at::cuda::graph_pool_handle);

  shared_ptr_class_<Graph> graph_cls(torch_C, "_CUDAGraph");
  graph_cls.def(py::init<>());
  // pybind11 finds the call_guard among the extra arguments regardless of its
  // position. py::arg entries only need to follow the C++ parameter order.
  graph_cls.def(
      "capture_begin",
      &Graph::capture_begin,
      NoGil(),
      py::arg("pool") = c10::cuda::MempoolId_t{0, 0});
  graph_cls.def("capture_end", &Graph::capture_end, NoGil());
  graph_cls.def("replay", &Graph::replay, NoGil());
  graph_cls.def("reset", &Graph::reset, NoGil());
  graph_cls.def("pool", &Graph::pool, NoGil());
}