#include <ATen/core/functional.h>
#include <c10/util/irange.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>

namespace torch {
namespace distributed {
namespace autograd {

using torch::autograd::Variable;
using torch::autograd::variable_list;
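
// Backward function attached to a `recv` node in the local autograd graph.
// A `recv` node is inserted whenever this worker receives tensors over RPC
// during the forward pass; during the backward pass, this function forwards
// the accumulated gradients back to the worker that sent those tensors.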
RecvRpcBackward::RecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    ContextPtr autogradContext,
    rpc::worker_id_t fromWorkerId,
    rpc::DeviceMap deviceMap)
    : autogradMetadata_(autogradMetadata),
      // NOLINTNEXTLINE(performance-move-const-arg)
      autogradContext_(std::move(autogradContext)),
      fromWorkerId_(fromWorkerId),
      deviceMap_(std::move(deviceMap)) {}
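
// Collects the incoming gradients (substituting zeros for undefined ones) and
// ships them via RPC to the worker recorded at construction time; nothing is
// returned locally since no downstream node consumes these gradients.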
variable_list RecvRpcBackward::apply(variable_list&& grads) {
  std::vector<Variable> outputGrads;
  for (const auto i : c10::irange(grads.size())) {
    const auto& grad = grads[i];
    if (grad.defined()) {
      outputGrads.emplace_back(grad);
    } else {
      // Put in zeros for a tensor with no grad.
      outputGrads.emplace_back(input_metadata(i).zeros_like());
    }
  }

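  // The autograd context is held only as a weak_ptr; another thread may have
  // cleaned it up already (e.g. after an error), so validate the lock.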
  auto sharedContext = autogradContext_.lock();
  TORCH_CHECK(
      sharedContext,
      c10::str(
          "Autograd context no longer valid! This usually ",
          "means the autograd context was cleaned up by a different thread due ",
          "to an error before RecvRpcBackward had a chance to run"));

  // Send the gradients over the wire and record the future in the autograd
  // context. The keep_graph flag from the local GraphTask is forwarded so the
  // remote backward pass retains (or frees) its graph consistently.
  PropagateGradientsReq gradCall(
      autogradMetadata_,
      outputGrads,
      sharedContext->retrieveGraphTask()->keep_graph_);

  // Send the gradients over to the appropriate node. kUnsetRpcTimeout defers
  // to the default timeout configured on the RPC agent, and deviceMap_ tells
  // the agent how to place the sent tensors on the destination's devices.
  auto rpcAgent = rpc::RpcAgent::getCurrentRpcAgent();
  auto jitFuture = rpcAgent->send(
      rpcAgent->getWorkerInfo(fromWorkerId_),
      std::move(gradCall).toMessage(),
      rpc::kUnsetRpcTimeout,
      deviceMap_);

  // Record the future in the context.
  sharedContext->addOutstandingRpc(jitFuture);
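  // The distributed autograd engine will wait on all RPC futures recorded as
  // outstanding in the context before marking the backward pass complete.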

  // The 'recv' function sends the gradients over the wire using RPC; it
  // doesn't need to return anything to any downstream autograd function.
  return variable_list();
}

} // namespace autograd
} // namespace distributed
} // namespace torch