1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
#include <torch/csrc/distributed/autograd/rpc_messages/rref_backward_req.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/jit/serialization/pickle.h>
namespace torch {
namespace distributed {
namespace autograd {
using rpc::Message;
using rpc::MessageType;
RRefBackwardReq::RRefBackwardReq(
const rpc::RRefId& rrefId,
int64_t autogradContextId,
bool retainGraph)
: rrefId_(rrefId),
autogradContextId_(autogradContextId),
retainGraph_(retainGraph) {}
// Serializes this request into an RREF_BACKWARD_REQ message.
//
// The fields are packed into a tuple in the fixed order
// (rrefId, autogradContextId, retainGraph); fromMessage() reads them back by
// these same indices, so the order must stay in sync with it.
c10::intrusive_ptr<Message> RRefBackwardReq::toMessageImpl() && {
  std::vector<at::IValue> fields;
  fields.reserve(3);
  fields.emplace_back(rrefId_.toIValue());
  fields.emplace_back(autogradContextId_);
  fields.emplace_back(retainGraph_);

  // jit::pickle produces the byte payload and fills `tensors` through the
  // out-pointer; both are handed to the Message.
  std::vector<torch::Tensor> tensors;
  std::vector<char> bytes =
      jit::pickle(c10::ivalue::Tuple::create(std::move(fields)), &tensors);

  return c10::make_intrusive<Message>(
      std::move(bytes), std::move(tensors), MessageType::RREF_BACKWARD_REQ);
}
std::unique_ptr<RRefBackwardReq> RRefBackwardReq::fromMessage(
const Message& message) {
// Unpickle the message and retrieve tupleElements.
auto payload = static_cast<const char*>(message.payload().data());
auto payload_size = message.payload().size();
IValue tuple = jit::unpickle(
payload,
payload_size,
*rpc::RpcAgent::getCurrentRpcAgent()->getTypeResolver(),
message.tensors());
const auto& tupleElements = std::move(*std::move(tuple).toTuple()).elements();
// Build RRefBackwardReq.
TORCH_INTERNAL_ASSERT(tupleElements.size() == 3);
// Retrieve all fields.
bool retainGraph = tupleElements[2].toBool();
int64_t autogradContextId = tupleElements[1].toInt();
rpc::RRefId rrefId = rpc::RRefId::fromIValue(tupleElements[0]);
return std::make_unique<RRefBackwardReq>(
rrefId, autogradContextId, retainGraph);
}
// Returns a reference to the stored RRef id (valid while this object lives).
const rpc::RRefId& RRefBackwardReq::getRRefId() const {
  return this->rrefId_;
}
int64_t RRefBackwardReq::getAutogradContextId() const {
return autogradContextId_;
}
bool RRefBackwardReq::retainGraph() const {
return retainGraph_;
}
} // namespace autograd
} // namespace distributed
} // namespace torch
|