#include <torch/csrc/python_headers.h>

#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h>
#include <torch/csrc/utils/pybind.h>

#include <pybind11/chrono.h>

namespace torch {
namespace distributed {
namespace rpc {
namespace testing {
namespace {
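
// Convenience alias: bind a C++ class to Python with std::shared_ptr as
// its holder type.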
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;

PyObject* faulty_agent_init(PyObject* /* unused */) {
// Add the FaultyProcessGroupAgent and its backend options object to the
// Python module torch.distributed.rpc._testing.
auto faulty_agent_module =
THPObjectPtr(PyImport_ImportModule("torch.distributed.rpc._testing"));
if (!faulty_agent_module) {
throw python_error();
}
auto module = py::handle(faulty_agent_module).cast<py::module>();
// Import the rpc module so we can subclass ProcessGroupAgent.
py::module rpc_module = py::module::import("torch.distributed.rpc");
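
// FaultyProcessGroupRpcBackendOptions extends ProcessGroupRpcBackendOptions
// with fault-injection knobs: which message types to fail, which to delay,
// and how many times matching sends should fail.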
shared_ptr_class_<FaultyProcessGroupRpcBackendOptions>(
module,
"FaultyProcessGroupRpcBackendOptions",
rpc_module.attr("ProcessGroupRpcBackendOptions"))
.def(
py::init<
int,
float,
std::string,
std::vector<std::string>,
std::unordered_map<std::string, float>,
int>(),
py::arg("num_send_recv_threads"),
py::arg("rpc_timeout"),
py::arg("init_method"),
py::arg("messages_to_fail"),
py::arg("messages_to_delay"),
py::arg("num_fail_sends"))
.def_readwrite(
"num_send_recv_threads",
&ProcessGroupRpcBackendOptions::numSendRecvThreads)
.def_readwrite(
"messages_to_fail",
&FaultyProcessGroupRpcBackendOptions::messagesToFail)
.def_readwrite(
"messages_to_delay",
&FaultyProcessGroupRpcBackendOptions::messagesToDelay)
.def_readwrite(
"num_fail_sends", &FaultyProcessGroupRpcBackendOptions::numFailSends);
shared_ptr_class_<FaultyProcessGroupAgent>(
module, "FaultyProcessGroupAgent", rpc_module.attr("ProcessGroupAgent"))
.def(
py::init<
std::string,
std::shared_ptr<::c10d::ProcessGroup>,
int,
std::chrono::milliseconds,
const std::vector<std::string>&,
const std::unordered_map<std::string, float>&,
int>(),
py::arg("name"),
py::arg("process_group"),
py::arg("num_send_recv_threads"),
py::arg("rpc_timeout"),
py::arg("messages_to_fail"),
py::arg("messages_to_delay"),
py::arg("failNumSends"))
.def(
"join",
&ProcessGroupAgent::join,
py::call_guard<py::gil_scoped_release>())
.def(
"shutdown",
&ProcessGroupAgent::shutdown,
py::call_guard<py::gil_scoped_release>())
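// get_worker_info is overloaded, so each overload is selected with an
// explicit member-function-pointer cast.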
.def(
"get_worker_info",
(const WorkerInfo& (ProcessGroupAgent::*)(void)const) &
RpcAgent::getWorkerInfo,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_info",
(const WorkerInfo& (ProcessGroupAgent::*)(const std::string&)const) &
ProcessGroupAgent::getWorkerInfo,
py::call_guard<py::gil_scoped_release>())
.def(
"get_worker_infos",
(std::vector<WorkerInfo>(ProcessGroupAgent::*)() const) &
ProcessGroupAgent::getWorkerInfos,
py::call_guard<py::gil_scoped_release>());
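// Signal successful initialization back to the Python caller.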
Py_RETURN_TRUE;
}
} // namespace
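
// Method table handed to the CPython module machinery; exposes
// faulty_agent_init to Python as _faulty_agent_init.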
static PyMethodDef methods[] = { // NOLINT
{"_faulty_agent_init",
(PyCFunction)faulty_agent_init,
METH_NOARGS,
nullptr},
{nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
return methods;
}
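
// Rough Python-side usage sketch (illustrative only; assumes
// _faulty_agent_init has already been invoked wherever the method table
// above gets registered, and the message-type strings and values below
// are made-up examples, not taken from this file):
//
//   from torch.distributed.rpc._testing import (
//       FaultyProcessGroupRpcBackendOptions)
//   opts = FaultyProcessGroupRpcBackendOptions(
//       num_send_recv_threads=8,
//       rpc_timeout=30.0,
//       init_method="tcp://localhost:29500",
//       messages_to_fail=["RREF_FORK_REQUEST"],
//       messages_to_delay={"PYTHON_CALL": 1.5},
//       num_fail_sends=2)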
} // namespace testing
} // namespace rpc
} // namespace distributed
} // namespace torch