1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
|
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/script_remote_call.h>
#include <c10/util/C++17.h>
#include <torch/csrc/jit/serialization/pickle.h>
namespace torch {
namespace distributed {
namespace rpc {
ScriptRemoteCall::ScriptRemoteCall(
std::shared_ptr<Operator> op,
std::vector<at::IValue>&& stack,
const RRefId& retRRefId,
const ForkId& retForkId)
: ScriptCall(std::move(op), std::move(stack)),
retRRefId_(retRRefId),
retForkId_(retForkId) {}
ScriptRemoteCall::ScriptRemoteCall(
const c10::QualifiedName& qualifiedName,
std::vector<at::IValue>&& stack,
const RRefId& retRRefId,
const ForkId& retForkId,
const bool isAsyncExecution)
: ScriptCall(qualifiedName, std::move(stack), isAsyncExecution),
retRRefId_(retRRefId),
retForkId_(retForkId) {}
// Rebuilds a ScriptRemoteCall from the flat IValue list produced by
// toMessageImpl(), which appends [..., retRRefId, retForkId] after the
// ScriptCall fields.
//
// @param ivalues  Deserialized elements; mutated in place. The two trailing
//                 ids are popped here, ScriptCall::fromIValues is run on the
//                 remainder, and whatever is left is moved in as the call's
//                 argument stack.
// @return A newly allocated ScriptRemoteCall targeting either an operator or
//         a qualified name, depending on what ScriptCall::fromIValues found.
//
// NOTE(review): there is no size guard here; calling back() on a vector with
// fewer than two elements is undefined behavior. Callers are presumably
// trusted to pass well-formed payloads — confirm.
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromIValues(
    std::vector<at::IValue>& ivalues) {
  // The fork id was appended last, so it is popped first; each id is
  // decoded with its own type's factory.
  auto retForkId = ForkId::fromIValue(ivalues.back());
  ivalues.pop_back();
  auto retRRefId = RRefId::fromIValue(ivalues.back());
  ivalues.pop_back();

  auto scriptCallPtr = ScriptCall::fromIValues(ivalues);
  if (scriptCallPtr->hasOp()) {
    return std::make_unique<ScriptRemoteCall>(
        scriptCallPtr->op(), std::move(ivalues), retRRefId, retForkId);
  } else {
    return std::make_unique<ScriptRemoteCall>(
        scriptCallPtr->qualifiedName(),
        std::move(ivalues),
        retRRefId,
        retForkId,
        scriptCallPtr->isAsyncExecution());
  }
}
// Serializes this call into a SCRIPT_REMOTE_CALL message. The ScriptCall
// fields come first, followed by the return RRef id and then the fork id
// (fromIValues pops them back off in reverse order). Tensors encountered
// while pickling are collected into a side table that travels with the
// message.
c10::intrusive_ptr<Message> ScriptRemoteCall::toMessageImpl() && {
  std::vector<IValue> elements;
  ScriptCall::toIValues(elements);
  elements.emplace_back(retRRefId_.toIValue());
  elements.emplace_back(retForkId_.toIValue());

  auto tuple = c10::ivalue::Tuple::create(std::move(elements));
  std::vector<torch::Tensor> tensors;
  auto serialized = jit::pickle(std::move(tuple), &tensors);

  return c10::make_intrusive<Message>(
      std::move(serialized), std::move(tensors), MessageType::SCRIPT_REMOTE_CALL);
}
// Decodes a message payload back into a ScriptRemoteCall. The bytes are
// unpickled with the current RPC agent's type resolver and the message's
// tensor table; the resulting tuple's elements are copied into a vector
// and handed to fromIValues.
std::unique_ptr<ScriptRemoteCall> ScriptRemoteCall::fromMessage(
    const Message& message) {
  const auto& bytes = message.payload();
  auto unpickled = jit::unpickle(
      static_cast<const char*>(bytes.data()),
      bytes.size(),
      *RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
      message.tensors());
  auto elements = unpickled.toTupleRef().elements().vec();
  return fromIValues(elements);
}
} // namespace rpc
} // namespace distributed
} // namespace torch
|