1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138
|
#include <torch/csrc/distributed/rpc/message.h>
namespace torch {
namespace distributed {
namespace rpc {
// Default-constructs a message. Every member takes the default given in the
// class declaration (not visible in this file).
Message::Message() = default;

// Builds a message by taking ownership of `payload` and `tensors`.
// `id_` is not initialized here, so it keeps whatever default the class
// declaration provides.
Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type)
    : payload_(std::move(payload)),
      tensors_(std::move(tensors)),
      type_(type) {}

// Same as the three-argument constructor, but also sets an explicit `id`.
Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type,
    int64_t id)
    : payload_(std::move(payload)),
      tensors_(std::move(tensors)),
      type_(type),
      id_(id) {}

// Member-wise copy and move. The move constructor is noexcept, so standard
// containers can relocate Message elements by moving instead of copying.
Message::Message(const Message& other) = default;
Message::Message(Message&& other) noexcept = default;
// Copy assignment (lvalue targets only).
// Uses copy-and-swap: the full copy of `rhs` is built first, so if copying
// the payload or tensor vectors throws, *this is left unchanged. The result
// is then committed through the noexcept swap().
Message& Message::operator=(Message const& rhs) & {
  Message replacement(
      std::vector<char>(rhs.payload_),
      std::vector<torch::Tensor>(rhs.tensors_),
      rhs.type_,
      rhs.id_);
  swap(replacement);
  return *this;
}
// Move assignment (lvalue targets only).
// Steals `rhs`'s payload and tensor vectors into a temporary, then swaps it
// into *this. `rhs` keeps its type_ and id_ values; its vectors are left in
// the moved-from state. Self-move is safe: the swap puts the moved payload
// straight back.
// NOTE(review): unlike the move constructor, this is not marked noexcept.
// Adding it requires changing the header declaration too.
Message& Message::operator=(Message&& rhs) & {
  Message stolen(
      std::move(rhs.payload_),
      std::move(rhs.tensors_),
      rhs.type_,
      rhs.id_);
  swap(stolen);
  return *this;
}
// Exchanges the payload, tensors, type and id of *this and `rhs`.
// None of these operations throw, which is what lets the assignment
// operators offer the strong exception guarantee.
void Message::swap(Message& rhs) noexcept {
  payload_.swap(rhs.payload_);
  tensors_.swap(rhs.tensors_);
  std::swap(type_, rhs.type_);
  std::swap(id_, rhs.id_);
}
std::vector<char>&& Message::movePayload() && {
return std::move(payload_);
}
std::vector<char>& Message::payload() {
return payload_;
}
const std::vector<char>& Message::payload() const {
return payload_;
}
std::vector<torch::Tensor>&& Message::moveTensors() && {
return std::move(tensors_);
}
std::vector<torch::Tensor>& Message::tensors() {
return tensors_;
}
const std::vector<torch::Tensor>& Message::tensors() const {
return tensors_;
}
// Returns the MessageType tag stored on this message.
MessageType Message::type() const {
  return this->type_;
}
// Returns true iff this message's type is one of the request kinds below.
// Every other MessageType (responses included) yields false.
bool Message::isRequest() const {
  switch (type_) {
    // dist.rpc on builtin ops / Python UDFs
    case MessageType::SCRIPT_CALL:
    case MessageType::PYTHON_CALL:
    // dist.remote on builtin ops / Python UDFs
    case MessageType::SCRIPT_REMOTE_CALL:
    case MessageType::PYTHON_REMOTE_CALL:
    // RRef related internal messages
    case MessageType::SCRIPT_RREF_FETCH_CALL:
    case MessageType::PYTHON_RREF_FETCH_CALL:
    case MessageType::RREF_USER_DELETE:
    case MessageType::RREF_CHILD_ACCEPT:
    case MessageType::RREF_FORK_REQUEST:
    // Autograd messages
    case MessageType::BACKWARD_AUTOGRAD_REQ:
    case MessageType::FORWARD_AUTOGRAD_REQ:
    // Cleanup autograd context request
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ:
    // Run with profiling request
    case MessageType::RUN_WITH_PROFILING_REQ:
      return true;
    default:
      return false;
  }
}
// Returns true iff this message's type is one of the response kinds below.
// Every other MessageType (requests included) yields false.
bool Message::isResponse() const {
  switch (type_) {
    case MessageType::SCRIPT_RET: // ret of dist.rpc on builtin ops
    case MessageType::PYTHON_RET: // ret of dist.rpc on Python UDFs
    case MessageType::REMOTE_RET: // ret of dist.remote
    case MessageType::SCRIPT_RREF_FETCH_RET: // ret on RRef::toHere()
    case MessageType::PYTHON_RREF_FETCH_RET: // ret on RRef::toHere()
    case MessageType::EXCEPTION: // propagate back exceptions
    case MessageType::RREF_ACK: // ret of other types
    // Autograd responses
    case MessageType::BACKWARD_AUTOGRAD_RESP:
    case MessageType::FORWARD_AUTOGRAD_RESP:
    // Cleanup autograd context response
    case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP:
    // Run with profiling response
    case MessageType::RUN_WITH_PROFILING_RESP:
      return true;
    default:
      return false;
  }
}
int64_t Message::id() const {
return id_;
}
void Message::setId(int64_t id) {
id_ = id;
}
// Builds an EXCEPTION message carrying `e.what()` as its payload.
// Delegates to the std::string overload, which is assumed to be declared in
// the header, since its definition appears later in this file.
Message createExceptionResponse(const std::exception& e, int64_t id) {
  const std::string description = e.what();
  return createExceptionResponse(description, id);
}
// Builds an EXCEPTION message whose payload is the raw characters of
// `exceptionStr` (no trailing NUL is added) and which carries no tensors.
// The message is tagged with `id`.
Message createExceptionResponse(const std::string& exceptionStr, int64_t id) {
  std::vector<char> bytes;
  bytes.assign(exceptionStr.cbegin(), exceptionStr.cend());
  std::vector<torch::Tensor> noTensors;
  return Message(
      std::move(bytes), std::move(noTensors), MessageType::EXCEPTION, id);
}
} // namespace rpc
} // namespace distributed
} // namespace torch
|