File: init.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (125 lines) | stat: -rw-r--r-- 4,011 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h>
#include <torch/csrc/utils/pybind.h>

#include <pybind11/chrono.h>

namespace torch {
namespace distributed {
namespace rpc {
namespace testing {

namespace {

// Shorthand for a pybind11 class binding whose Python-side instances are held
// by std::shared_ptr<T>; used for the bindings registered in this file.
template <class T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

PyObject* faulty_agent_init(PyObject* /* unused */) {
  // Add the FaultyProcessGroupAgent and its backend options object to the
  // python module torch.distributed.rpc._testing
  auto faulty_agent_module =
      THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc._testing"));
  if (!faulty_agent_module) {
    throw python_error();
  }

  auto module = py::handle(faulty_agent_module).cast<py::module>();

  // Import the rpc_module so we can subclass ProcessGroupAgent
  py::module rpc_module = py::module::import("torch.distributed.rpc");

  shared_ptr_class_<FaultyProcessGroupRpcBackendOptions>(
      module,
      "FaultyProcessGroupRpcBackendOptions",
      rpc_module.attr("ProcessGroupRpcBackendOptions"))
      .def(
          py::init<
              int,
              float,
              std::string,
              std::vector<std::string>,
              std::unordered_map<std::string, float>,
              int>(),
          py::arg("num_send_recv_threads"),
          py::arg("rpc_timeout"),
          py::arg("init_method"),
          py::arg("messages_to_fail"),
          py::arg("messages_to_delay"),
          py::arg("num_fail_sends"))
      .def_readwrite(
          "num_send_recv_threads",
          &ProcessGroupRpcBackendOptions::numSendRecvThreads)
      .def_readwrite(
          "messages_to_fail",
          &FaultyProcessGroupRpcBackendOptions::messagesToFail)
      .def_readwrite(
          "messages_to_delay",
          &FaultyProcessGroupRpcBackendOptions::messagesToDelay)
      .def_readwrite(
          "num_fail_sends", &FaultyProcessGroupRpcBackendOptions::numFailSends);

  shared_ptr_class_<FaultyProcessGroupAgent>(
      module, "FaultyProcessGroupAgent", rpc_module.attr("ProcessGroupAgent"))
      .def(
          py::init<
              std::string,
              std::shared_ptr<::c10d::ProcessGroup>,
              int,
              std::chrono::milliseconds,
              const std::vector<std::string>&,
              const std::unordered_map<std::string, float>&,
              int>(),
          py::arg("name"),
          py::arg("process_group"),
          py::arg("num_send_recv_threads"),
          py::arg("rpc_timeout"),
          py::arg("messages_to_fail"),
          py::arg("messages_to_delay"),
          py::arg("failNumSends"))
      .def(
          "join",
          &ProcessGroupAgent::join,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "shutdown",
          &ProcessGroupAgent::shutdown,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (ProcessGroupAgent::*)(void)const) &
              RpcAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_info",
          (const WorkerInfo& (ProcessGroupAgent::*)(const std::string&)const) &
              ProcessGroupAgent::getWorkerInfo,
          py::call_guard<py::gil_scoped_release>())
      .def(
          "get_worker_infos",
          (std::vector<WorkerInfo>(ProcessGroupAgent::*)() const) &
              ProcessGroupAgent::getWorkerInfos,
          py::call_guard<py::gil_scoped_release>());

  Py_RETURN_TRUE;
}

} // namespace

// Method table for this file: a single argument-less (METH_NOARGS) entry
// `_faulty_agent_init` without a docstring, followed by the all-null entry
// that terminates a PyMethodDef array.
static PyMethodDef methods[] = { // NOLINT
    {
        "_faulty_agent_init", // ml_name
        reinterpret_cast<PyCFunction>(faulty_agent_init), // ml_meth
        METH_NOARGS, // ml_flags
        nullptr, // ml_doc
    },
    {nullptr, nullptr, 0, nullptr}, // terminator
};

// Exposes this file's PyMethodDef table (terminated by an all-null entry) to
// whoever assembles the module's method list.
PyMethodDef* python_functions() {
  PyMethodDef* const table = &methods[0];
  return table;
}

} // namespace testing
} // namespace rpc
} // namespace distributed
} // namespace torch