1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159
|
#include <ATen/ThreadLocalState.h>
#include <fmt/format.h>
#include <torch/csrc/autograd/record_function_ops.h>
#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/profiler/remote_profiler_manager.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_proto.h>
#include <torch/csrc/distributed/rpc/script_call.h>
#include <torch/csrc/distributed/rpc/torchscript_functions.h>
#include <torch/csrc/distributed/rpc/utils.h>
namespace torch {
namespace distributed {
namespace rpc {
// Invokes the TorchScript function `qualifiedName` on worker `dstWorkerName`
// over RPC and returns a future that will hold the function's result.
//
// - `functionSchema` must declare exactly one return value; its type becomes
//   the element type of the returned JitFuture.
// - `stack` holds the call arguments and is moved into the outgoing
//   ScriptCall, so the caller's vector is left in a moved-from state.
// - The message is sent through sendMessageWithAutograd with
//   forceGradRecording=true and the given `rpcTimeoutSeconds`.
// - If the remote call fails, the returned future carries the error;
//   otherwise it is completed with the deserialized response.
// - When the autograd profiler is enabled and RemoteProfilerManager has no
//   current key set, the call is wrapped in an "rpc_async_jit#..." record
//   function, and the future returned by _call_end_callbacks_on_fut_new is
//   returned instead of the plain result future.
c10::intrusive_ptr<JitFuture> rpcTorchscript(
    const std::string& dstWorkerName,
    const c10::QualifiedName& qualifiedName,
    const c10::FunctionSchema& functionSchema,
    std::vector<c10::IValue>& stack,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  // Validate the schema before any side effect (profiler range, profiler key,
  // outgoing message). A failure here then cannot leave a half-sent RPC or an
  // unclosed profiling range behind.
  // Bind by reference: returns() exposes the schema's own vector, so there is
  // no need to copy it.
  const auto& returns = functionSchema.returns();
  // Script call only allows single IValue returned.
  TORCH_INTERNAL_ASSERT(
      returns.size() == 1,
      "Return value of an annotated torchScript function should be a single "
      "IValue, but got ",
      returns.size(),
      " return values.");
  auto returnType = returns.at(0).type();

  auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();

  c10::intrusive_ptr<torch::autograd::profiler::PythonRecordFunction> record;
  auto shouldProfile = torch::autograd::profiler::profilerEnabled() &&
      !torch::distributed::rpc::RemoteProfilerManager::getInstance()
           .isCurrentKeySet();
  if (shouldProfile) {
    auto rpcAsyncJitKey = fmt::format(
        "rpc_async_jit#{}({} -> {})",
        qualifiedName
            .qualifiedName(), /* name of torchscript function being run */
        rpcAgentPtr->getWorkerInfo().name_,
        dstWorkerName);
    record =
        torch::autograd::profiler::record_function_enter_new(rpcAsyncJitKey);
    auto& remoteProfilerManager =
        torch::distributed::rpc::RemoteProfilerManager::getInstance();
    remoteProfilerManager.setCurrentKey(rpcAsyncJitKey);
  }

  auto scriptCall = std::make_unique<ScriptCall>(
      qualifiedName, std::move(stack), isAsyncExecution);
  auto jitFuture = autograd::sendMessageWithAutograd(
      *rpcAgentPtr,
      rpcAgentPtr->getWorkerInfo(dstWorkerName),
      std::move(*scriptCall).toMessage(),
      true /*forceGradRecording*/,
      rpcTimeoutSeconds);

  // Create a JIT future typed with the function's return type. The callback
  // on the message future completes it (or propagates the error).
  auto futPtr = jitFuture->createInstance(returnType);
  jitFuture->addCallback(at::wrapPropagateTLSState([futPtr](JitFuture& future) {
    if (future.hasError()) {
      futPtr->setError(future.exception_ptr());
    } else {
      futPtr->markCompleted(
          deserializeRespToIValue(
              *future.constValue().toCustomClass<Message>()),
          future.storages());
    }
  }));

  if (shouldProfile) {
    // Closes the profiling range opened above once futPtr completes.
    return torch::autograd::profiler::_call_end_callbacks_on_fut_new(
        record, futPtr);
  }
  return futPtr;
}
// Runs the TorchScript function `qualifiedName` on worker `dstWorkerName` and
// returns an RRef to its (future) result.
//
// - `functionSchema` must declare exactly one return value; its type becomes
//   the RRef's value type.
// - `stack` holds the call arguments and is moved into the outgoing
//   ScriptRemoteCall, so the caller's vector is left in a moved-from state.
// - If the destination is another worker, a UserRRef is created and
//   registered as a pending user. The pending user is confirmed by a callback
//   on the message future.
// - If the destination is the current worker, an OwnerRRef is created and
//   added as its own fork. The callback then finishes owner-RRef creation
//   when the message future completes.
// - Messages go through sendMessageWithAutograd with forceGradRecording=true
//   and the given `rpcTimeoutSeconds`.
c10::intrusive_ptr<RRef> remoteTorchscript(
    const std::string& dstWorkerName,
    const c10::QualifiedName& qualifiedName,
    const c10::FunctionSchema& functionSchema,
    std::vector<c10::IValue>& stack,
    const float rpcTimeoutSeconds,
    const bool isAsyncExecution) {
  auto rpcAgentPtr = RpcAgent::getCurrentRpcAgent();
  auto dstWorkerInfo = rpcAgentPtr->getWorkerInfo(dstWorkerName);
  auto& ctx = RRefContext::getInstance();

  // Get function return type to construct the RRef. Bind by reference:
  // returns() exposes the schema's own vector, so there is no need to copy it.
  const auto& returns = functionSchema.returns();
  // Script call only allows single IValue returned.
  TORCH_INTERNAL_ASSERT(
      returns.size() == 1,
      "Return value of an annotated torchScript function should be a single "
      "IValue, but got ",
      returns.size(),
      " return values.");
  auto returnType = returns.at(0).type();

  if (ctx.getWorkerId() != dstWorkerInfo.id_) {
    // Remote owner: this worker only holds a UserRRef to the result.
    auto userRRefPtr = ctx.createUserRRef(dstWorkerInfo.id_, returnType);
    auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
        qualifiedName,
        std::move(stack),
        userRRefPtr->rrefId(),
        userRRefPtr->forkId(),
        isAsyncExecution);
    auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd(
        *rpcAgentPtr,
        dstWorkerInfo,
        std::move(*scriptRemoteCall).toMessage(),
        true /*forceGradRecording*/,
        rpcTimeoutSeconds /* timeout */);
    userRRefPtr->registerOwnerCreationFuture(jitFuture);
    ctx.addPendingUser(userRRefPtr->forkId(), userRRefPtr);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [forkId{userRRefPtr->forkId()}](JitFuture& future) {
          callback::confirmPendingUser(future, forkId);
        }));
    return userRRefPtr;
  } else {
    // Local owner: this worker owns the result directly.
    auto ownerRRefPtr = ctx.createOwnerRRef(returnType);
    // prevent this owner RRef from being deleted due to other forks
    ctx.addSelfAsFork(ownerRRefPtr);
    // The owner's rrefId is passed for both the RRef id and the fork id.
    auto scriptRemoteCall = std::make_unique<ScriptRemoteCall>(
        qualifiedName,
        std::move(stack),
        ownerRRefPtr->rrefId(),
        ownerRRefPtr->rrefId(),
        isAsyncExecution);
    auto jitFuture = torch::distributed::autograd::sendMessageWithAutograd(
        *rpcAgentPtr,
        dstWorkerInfo,
        std::move(*scriptRemoteCall).toMessage(),
        true /*forceGradRecording*/,
        rpcTimeoutSeconds /* timeout */);
    ownerRRefPtr->registerOwnerCreationFuture(jitFuture);
    jitFuture->addCallback(at::wrapPropagateTLSState(
        [ownerRRefId = ownerRRefPtr->rrefId()](JitFuture& future) {
          callback::finishCreatingOwnerRRef(future, ownerRRefId);
        }));
    return ownerRRefPtr;
  }
}
} // namespace rpc
} // namespace distributed
} // namespace torch
|