#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/distributed/autograd/autograd.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/types.h>
namespace torch {
namespace distributed {
namespace autograd {
namespace {
template <typename T>
using shared_ptr_class_ = py::class_<T, std::shared_ptr<T>>;
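// Creates the torch._C._distributed_autograd submodule and registers all
// distributed autograd bindings on it. Exposed to Python as
// torch._C._dist_autograd_init via the method table at the bottom of this file.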
PyObject* dist_autograd_init(PyObject* _unused, PyObject* noargs) {
auto autograd_module =
THPObjectPtr(PyImport_ImportModule("torch.distributed.autograd"));
if (!autograd_module) {
throw python_error();
}
auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
if (!torch_C_module) {
throw python_error();
}
auto torch_C_m = py::handle(torch_C_module).cast<py::module>();
auto m = torch_C_m.def_submodule(
"_distributed_autograd", "distributed autograd bindings");
auto module = py::handle(m).cast<py::module>();
auto distAutogradContext =
shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
.def(
"_context_id",
&DistAutogradContext::contextId,
py::call_guard<py::gil_scoped_release>())
.def(
"_recv_functions",
[](const DistAutogradContext& ctx) {
std::map<int64_t, py::object> funcs;
auto recvFunctions = ctx.recvFunctions();
// Acquire GIL only when necessary to avoid deadlocks.
pybind11::gil_scoped_acquire ag;
for (const auto& map_entry : recvFunctions) {
funcs.emplace(
map_entry.first,
py::reinterpret_steal<py::object>(
torch::autograd::functionToPyObject(
map_entry.second)));
}
return funcs;
},
py::call_guard<py::gil_scoped_release>())
.def(
"_send_functions",
[](const ContextPtr& ctx) {
std::map<int64_t, py::object> funcs;
auto sendFunctions = ctx->sendFunctions();
// Acquire GIL only when necessary to avoid deadlocks.
pybind11::gil_scoped_acquire ag;
for (const auto& map_entry : sendFunctions) {
funcs.emplace(
map_entry.first,
py::reinterpret_steal<py::object>(
torch::autograd::functionToPyObject(
map_entry.second)));
}
return funcs;
},
py::call_guard<py::gil_scoped_release>())
.def(
"_known_worker_ids",
&DistAutogradContext::getKnownWorkerIds,
py::call_guard<py::gil_scoped_release>());
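// Illustrative sketch (not part of the bindings themselves): assuming the
// private accessors above are re-exported by torch.distributed.autograd, a
// context can be inspected from Python roughly as follows:
//
//   import torch.distributed.autograd as dist_autograd
//   with dist_autograd.context() as context_id:
//       ...  # forward pass issuing RPCs
//       ctx = dist_autograd._retrieve_context(context_id)
//       send_fns = ctx._send_functions()  # {autograd_message_id: send autograd function}
//       recv_fns = ctx._recv_functions()  # {autograd_message_id: recv autograd function}
//       workers = ctx._known_worker_ids()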
module.def(
"_new_context",
[]() -> const ContextPtr {
return DistAutogradContainer::getInstance().newContext();
},
py::return_value_policy::reference,
py::call_guard<py::gil_scoped_release>());
module.def(
"_release_context",
[](int64_t context_id) {
return DistAutogradContainer::getInstance().releaseContext(context_id);
},
py::call_guard<py::gil_scoped_release>());
module.def(
"_get_max_id",
[]() { return DistAutogradContainer::getInstance().getMaxId(); },
py::call_guard<py::gil_scoped_release>());
module.def(
"_is_valid_context",
[](int64_t context_id) {
DistAutogradContainer::getInstance().isValidContext(context_id);
},
py::call_guard<py::gil_scoped_release>());
module.def(
"_retrieve_context",
[](int64_t context_id) -> const ContextPtr {
return DistAutogradContainer::getInstance().retrieveContext(context_id);
},
py::return_value_policy::reference,
py::call_guard<py::gil_scoped_release>());
module.def(
"_current_context",
[]() -> const ContextPtr {
return DistAutogradContainer::getInstance().currentContext();
},
py::return_value_policy::reference,
py::call_guard<py::gil_scoped_release>());
module.def(
"_init",
[](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
py::call_guard<py::gil_scoped_release>());
module.def(
"_get_debug_info",
[]() { return DistEngine::getInstance().getDebugInfo(); },
py::call_guard<py::gil_scoped_release>());
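// Rough sketch of the lifecycle of the private container bindings above; in a
// real program the container is normally initialized by the RPC framework
// rather than called directly, so treat this as an illustrative assumption:
//
//   dist_autograd._init(worker_id)          # set up the per-worker container
//   ctx = dist_autograd._new_context()      # create a fresh autograd context
//   context_id = ctx._context_id()
//   ...                                     # forward / backward work
//   dist_autograd._release_context(context_id)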
py::options options;
options.disable_function_signatures();
module.def(
"backward",
backward,
R"(
backward(context_id: int, roots: List[Tensor], retain_graph=False) -> None

Kicks off the distributed backward pass using the provided roots. This
currently implements the :ref:`fast-mode-algorithm` which
assumes all RPC messages sent in the same distributed autograd context
across workers would be part of the autograd graph during the backward pass.

We use the provided roots to discover the autograd graph and compute
appropriate dependencies. This method blocks until the entire
autograd computation is done.

We accumulate the gradients in the appropriate
:class:`torch.distributed.autograd.context` on each of the nodes. The autograd
context to be used is looked up given the ``context_id`` that is passed in when
:meth:`torch.distributed.autograd.backward` is called. If there is no valid
autograd context corresponding to the given ID, we throw an error. You can
retrieve the accumulated gradients using the
:meth:`~torch.distributed.autograd.get_gradients` API.

Arguments:
    context_id (int): The autograd context id in which gradients should be
        accumulated.
    roots (list): Tensors which represent the roots of the autograd
        computation. All the tensors should be scalars.
    retain_graph (bool, optional): If False, the graph used to compute the grad
        will be freed. Note that in nearly all cases setting this
        option to True is not needed and often can be worked around
        in a much more efficient way. Usually, you need to set this
        to True to run backward multiple times.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     pred = model.forward()
    >>>     loss = loss_func(pred, target)
    >>>     dist_autograd.backward(context_id, [loss])
)",
py::arg("contextId"),
py::arg("roots"),
py::arg("retain_graph") = false,
py::call_guard<py::gil_scoped_release>());
module.def(
"get_gradients",
[](int64_t contextId) -> py::dict {
const auto& autogradContext =
DistAutogradContainer::getInstance().retrieveContext(contextId);
auto ival = IValue(autogradContext->getGradients());
// Acquire GIL only for pyobject conversion.
pybind11::gil_scoped_acquire ag;
return torch::jit::toPyObject(ival);
},
R"(
get_gradients(context_id: int) -> Dict[Tensor, Tensor]

Retrieves a map from Tensor to the appropriate gradient for that Tensor
accumulated in the provided context corresponding to the given ``context_id``
as part of the distributed autograd backward pass.

Arguments:
    context_id (int): The autograd context id for which we should retrieve the
        gradients.

Returns:
    A map where the key is the Tensor and the value is the associated gradient
    for that Tensor.

Example::
    >>> import torch.distributed.autograd as dist_autograd
    >>> with dist_autograd.context() as context_id:
    >>>     t1 = torch.rand((3, 3), requires_grad=True)
    >>>     t2 = torch.rand((3, 3), requires_grad=True)
    >>>     loss = t1 + t2
    >>>     dist_autograd.backward(context_id, [loss.sum()])
    >>>     grads = dist_autograd.get_gradients(context_id)
    >>>     print(grads[t1])
    >>>     print(grads[t2])
)",
py::arg("context_id"),
py::call_guard<py::gil_scoped_release>());
Py_RETURN_TRUE;
}
} // namespace
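// Python method table: exposes dist_autograd_init as torch._C._dist_autograd_init.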
static PyMethodDef methods[] = { // NOLINT
{"_dist_autograd_init", dist_autograd_init, METH_NOARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
PyMethodDef* python_functions() {
return methods;
}
} // namespace autograd
} // namespace distributed
} // namespace torch