#include <ATen/ThreadLocalState.h>
#include <c10/util/ThreadLocalDebugInfo.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/distributed/autograd/context/container.h>
#include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
#include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/types.h>

#include <algorithm>
namespace torch {
namespace distributed {
namespace autograd {
using torch::distributed::autograd::AutogradMetadata;
using torch::distributed::autograd::RpcWithAutograd;
using torch::distributed::rpc::JitFuture;
using torch::distributed::rpc::Message;
using torch::distributed::rpc::MessageType;
using torch::distributed::rpc::RpcAgent;
using torch::distributed::rpc::RpcCommandBase;
using torch::distributed::rpc::WorkerInfo;
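
// Builds a 'SendRpcBackward' autograd function whose next edges point at the
// tensors that require grad, and records it in the given distributed autograd
// context under the provided autograd message id.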
void addSendRpcBackward(
    const ContextPtr& autogradContext,
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors) {
  // Attach autograd information only for tensors requiring grad.
  std::vector<torch::Tensor> tensors_with_grad;
  std::copy_if(
      tensors.begin(),
      tensors.end(),
      std::back_inserter(tensors_with_grad),
      [](const torch::Tensor& t) { return t.requires_grad(); });

  // Attach the appropriate autograd edges.
  auto grad_fn = std::make_shared<SendRpcBackward>();
  grad_fn->set_next_edges(
      torch::autograd::collect_next_edges(tensors_with_grad));

  // Add the appropriate input metadata for the grad_fn.
  for (const auto& tensor : tensors_with_grad) {
    grad_fn->add_input_metadata(tensor);
  }

  // Record the send autograd function in our current context.
  autogradContext->addSendFunction(grad_fn, autogradMetadata.autogradMessageId);
}
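
// Gets or creates the distributed autograd context for the given autograd
// context id. If any received tensor requires grad, attaches a
// 'RecvRpcBackward' function as the grad_fn of those tensors and records it
// in the context under the autograd message id. Returns the context in
// either case.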
ContextPtr addRecvRpcBackward(
    const AutogradMetadata& autogradMetadata,
    std::vector<torch::Tensor>& tensors,
    rpc::worker_id_t fromWorkerId,
    const rpc::DeviceMap& deviceMap) {
  // Initialize autograd context if necessary.
  auto& autogradContainer = DistAutogradContainer::getInstance();
  auto autogradContext =
      autogradContainer.getOrCreateContext(autogradMetadata.autogradContextId);

  if (!tensors.empty() && torch::autograd::compute_requires_grad(tensors)) {
    // Attach the tensors as inputs to the autograd function.
    auto grad_fn = std::make_shared<RecvRpcBackward>(
        autogradMetadata, autogradContext, fromWorkerId, deviceMap);
    for (auto& tensor : tensors) {
      if (tensor.requires_grad()) {
        torch::autograd::set_history(tensor, grad_fn);
      }
    }

    // Now update the autograd context with the necessary information.
    autogradContext->addRecvFunction(
        grad_fn, autogradMetadata.autogradMessageId);
  }

  return autogradContext;
}
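
// Wraps the given RPC message in a 'RpcWithProfilingReq' that carries the
// caller's profiler configuration and a freshly generated, globally unique
// profiling id. The id -> current RPC profiling key mapping is saved in the
// RemoteProfilerManager and the thread-local key is unset.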
c10::intrusive_ptr<Message> getMessageWithProfiling(
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMessage,
    MessageType msgType,
    torch::autograd::profiler::ProfilerConfig&& profilerConfig) {
  auto& remoteProfilerManager =
      torch::distributed::rpc::RemoteProfilerManager::getInstance();

  auto key = remoteProfilerManager.getCurrentProfilingKey();
  // Generate a globally unique Id.
  auto globallyUniqueProfilingId = remoteProfilerManager.getNextProfilerId();
  // Save a mapping of ID -> RPC profiling key and unset the current TLS key.
  remoteProfilerManager.saveRPCKey(globallyUniqueProfilingId, key);
  remoteProfilerManager.unsetCurrentKey();
  auto wrappedProfilingMsg = RpcWithProfilingReq(
      msgType,
      std::move(wrappedRpcMessage),
      // NOLINTNEXTLINE(performance-move-const-arg)
      std::move(profilerConfig),
      globallyUniqueProfilingId);

  return std::move(wrappedProfilingMsg).toMessage();
}
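
// Returns the original message unchanged if there is no valid distributed
// autograd context, or if grad recording is not forced and no tensor requires
// grad. Otherwise wraps the message in a 'RpcWithAutograd', records the
// 'send' autograd function when tensors require grad, and records the
// destination worker id in the current context.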
c10::intrusive_ptr<Message> getMessageWithAutograd(
    const rpc::worker_id_t dstId,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    MessageType msgType,
    bool forceGradRecording,
    const rpc::DeviceMap& deviceMap) {
  auto& autogradContainer = DistAutogradContainer::getInstance();

  // If there is no valid context and no tensor requires grad, send the
  // original rpc message. Otherwise, attach grad info and grad functions and
  // send an rpcWithAutograd message.
  auto tensorsRequireGrad =
      torch::autograd::compute_requires_grad(wrappedRpcMsg->tensors());
  if (!autogradContainer.hasValidContext() ||
      (!forceGradRecording && !tensorsRequireGrad)) {
    return wrappedRpcMsg;
  }

  // Retrieve the appropriate context to modify.
  auto autogradContext = autogradContainer.currentContext();

  // Wrap the original rpc with autograd information.
  AutogradMetadata autogradMetadata(
      autogradContext->contextId(), autogradContainer.newAutogradMessageId());
  auto rpcWithAutograd = std::make_unique<RpcWithAutograd>(
      RpcAgent::getCurrentRpcAgent()->getWorkerInfo().id_,
      msgType,
      autogradMetadata,
      std::move(wrappedRpcMsg),
      deviceMap);

  if (tensorsRequireGrad) {
    // Record autograd information for 'send'.
    addSendRpcBackward(
        autogradContext, autogradMetadata, rpcWithAutograd->tensors());
  }
  // Record the destination workerId.
  autogradContext->addKnownWorkerId(dstId);

  return std::move(*rpcWithAutograd).toMessage();
}
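
// Attaches autograd information to the message via getMessageWithAutograd and
// sends it to the destination worker through the given agent. When the legacy
// profiler is enabled (and profiling is not force-disabled), the message is
// additionally wrapped with profiling metadata; the Kineto profiler only
// profiles the caller side.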
c10::intrusive_ptr<JitFuture> sendMessageWithAutograd(
    RpcAgent& agent,
    const WorkerInfo& dst,
    c10::intrusive_ptr<torch::distributed::rpc::Message> wrappedRpcMsg,
    bool forceGradRecording,
    const float rpcTimeoutSeconds,
    bool forceDisableProfiling) {
  auto msg = getMessageWithAutograd(
      dst.id_,
      std::move(wrappedRpcMsg),
      MessageType::FORWARD_AUTOGRAD_REQ,
      forceGradRecording,
      agent.getDeviceMap(dst));

  // If the profiler is enabled, wrap this message with profiling metadata
  // that tells the remote end to process this request with the profiler
  // enabled.
  if (!forceDisableProfiling) {
    switch (torch::profiler::impl::profilerType()) {
      case torch::profiler::impl::ActiveProfilerType::LEGACY: {
        auto profilerConfig = torch::autograd::profiler::getProfilerConfig();
        auto msgWithProfiling = getMessageWithProfiling(
            std::move(msg),
            rpc::MessageType::RUN_WITH_PROFILING_REQ,
            // NOLINTNEXTLINE(performance-move-const-arg)
            std::move(profilerConfig));
        return agent.send(dst, std::move(msgWithProfiling), rpcTimeoutSeconds);
      }
      case torch::profiler::impl::ActiveProfilerType::KINETO:
        TORCH_WARN_ONCE(
            "Profiling a distributed call with the Kineto profiler will profile "
            "the caller, but not the worker.");
        break;
      default:
        break;
    }
  }

  return agent.send(dst, std::move(msg), rpcTimeoutSeconds);
}
} // namespace autograd
} // namespace distributed
} // namespace torch