1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
#pragma once
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/jit/python/pybind_utils.h>
namespace c10d {
// PyProcessGroup is a pybind11 trampoline class that allows a Python class
// to inherit from torch.distributed.ProcessGroup.
//
// Each virtual method below forwards to a Python-side override through the
// PYBIND11_OVERRIDE* macros (see the pybind11 documentation on overriding
// virtual functions for the exact dispatch rules). Every trampoline method
// is declared `override` so the compiler rejects any signature drift from
// the base class instead of silently creating a non-overriding overload.
class PyProcessGroup : public ProcessGroup {
 public:
  // PyWork is a pybind11 trampoline class that allows a Python class
  // to inherit from torch.distributed.Work.
  class PyWork : public Work {
   public:
    PyWork() = default;

    // Forwards `wait` (with the caller's `timeout`) to the Python override,
    // using Work::wait as the parent implementation named to the macro.
    bool wait(std::chrono::milliseconds timeout = kNoTimeout) override {
      PYBIND11_OVERRIDE(
          bool, /* Return type */
          Work, /* Parent class */
          wait, /* Name of function in C++ */
          timeout);
    }

    // Returns the future associated with this work object.
    //
    // If the Python subclass defines `get_future`, its result is cast to a
    // PythonFutureWrapper and the wrapped `fut` is returned; otherwise this
    // falls back to Work::getFuture().
    c10::intrusive_ptr<c10::ivalue::Future> getFuture() override {
      // We cannot use PYBIND11_OVERRIDE because:
      // 1. We have to >MANUALLY< unwrap the PyFutureWrapper and
      // 2. The python name is get_future
      pybind11::gil_scoped_acquire gil;
      auto pyOverride = pybind11::get_override(
          static_cast<const Work*>(this), "get_future");

      if (pyOverride) {
        pybind11::object result = pyOverride();
        auto futWrapper =
            result.cast<std::shared_ptr<torch::jit::PythonFutureWrapper>>();
        return futWrapper->fut;
      }

      return Work::getFuture();
    }
  };

  // Inherit the constructors of ProcessGroup.
  using ProcessGroup::ProcessGroup;

  // Forwards to the Python `getBackendName` override. Uses the *_PURE macro
  // variant, so the Python subclass is expected to provide this method.
  const std::string getBackendName() const override {
    PYBIND11_OVERRIDE_PURE(
        std::string, /* Return type */
        ProcessGroup, /* Parent class */
        getBackendName, /* Name of function in C++ */
    );
  }

  // Forwards `allgather` and its arguments to the Python override.
  c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& outputTensors,
      std::vector<at::Tensor>& inputTensors,
      const AllgatherOptions& opts = AllgatherOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        allgather, /* Name of function in C++ */
        outputTensors,
        inputTensors,
        opts);
  }

  // Forwards `allreduce` and its arguments to the Python override.
  c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& tensors,
      const AllreduceOptions& opts = AllreduceOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        allreduce, /* Name of function in C++ */
        tensors,
        opts);
  }

  // Forwards `barrier` and its options to the Python override.
  // Marked `override` like every other trampoline method so that a mismatch
  // with ProcessGroup::barrier is a compile error rather than a silent
  // failure to dispatch to Python.
  c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& opts = BarrierOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        barrier, /* Name of function in C++ */
        opts);
  }

  // Forwards `broadcast` and its arguments to the Python override.
  c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& tensors,
      const BroadcastOptions& opts = BroadcastOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        broadcast, /* Name of function in C++ */
        tensors,
        opts);
  }

  // Forwards `reduce_scatter` and its arguments to the Python override.
  c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& outputTensors,
      std::vector<std::vector<at::Tensor>>& inputTensors,
      const ReduceScatterOptions& opts = ReduceScatterOptions()) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        reduce_scatter, /* Name of function in C++ */
        outputTensors,
        inputTensors,
        opts);
  }

  // Forwards `send` (tensors, destination rank, tag) to the Python override.
  c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& tensors,
      int dstRank,
      int tag) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        send, /* Name of function in C++ */
        tensors,
        dstRank,
        tag);
  }

  // Forwards `recv` (tensors, source rank, tag) to the Python override.
  c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& tensors,
      int srcRank,
      int tag) override {
    PYBIND11_OVERRIDE(
        c10::intrusive_ptr<Work>, /* Return type */
        ProcessGroup, /* Parent class */
        recv, /* Name of function in C++ */
        tensors,
        srcRank,
        tag);
  }
};
} // namespace c10d
|