1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193
|
#pragma once
#include <torch/types.h>
#include <vector>
namespace torch {
namespace distributed {
namespace rpc {
// Categories of common RPC failures, so callers can handle specific errors
// differently from generic ones.
enum RPCErrorType {
  // The error type could not be parsed.
  UNKNOWN_ERROR = 0,
  // The RPC timed out.
  TIMEOUT = 1,
  // A deliberate failure, such as one injected by FaultyAgent for testing.
  INTENTIONAL_FAILURE = 2,
};
// Flag bits that are bitwise ORed into MessageType values to mark a message
// as a request or a response. Each flag must be a single bit placed above the
// low byte (which MessageType reserves for the message id), i.e. one of
// 0x100, 0x200, 0x400, 0x800, 0x1000, etc. The static_asserts below enforce
// this, so a new flag cannot collide with message ids or with another flag.
enum MessageTypeFlags {
  REQUEST_TYPE = 0x100,
  RESPONSE_TYPE = 0x200,
};
static_assert(
    REQUEST_TYPE != 0 && (REQUEST_TYPE & (REQUEST_TYPE - 1)) == 0 &&
        (REQUEST_TYPE & 0xff) == 0,
    "REQUEST_TYPE must be a single bit above the message-id byte");
static_assert(
    RESPONSE_TYPE != 0 && (RESPONSE_TYPE & (RESPONSE_TYPE - 1)) == 0 &&
        (RESPONSE_TYPE & 0xff) == 0,
    "RESPONSE_TYPE must be a single bit above the message-id byte");
static_assert(
    (REQUEST_TYPE & RESPONSE_TYPE) == 0,
    "Request and response flags must not overlap");
// Every MessageType value combines a message id held in the low byte
// (0x00 through 0xff) with one MessageTypeFlags bit, ORed in, that marks the
// message as a request or a response. UNKNOWN carries no flag bit.
enum MessageType {
  // --- dist.rpc on builtin operators ---
  SCRIPT_CALL = MessageTypeFlags::REQUEST_TYPE | 0x00,
  SCRIPT_RET = MessageTypeFlags::RESPONSE_TYPE | 0x01,

  // --- dist.rpc on Python UDFs ---
  PYTHON_CALL = MessageTypeFlags::REQUEST_TYPE | 0x02,
  PYTHON_RET = MessageTypeFlags::RESPONSE_TYPE | 0x03,

  // --- dist.remote on builtin operators and Python UDFs ---
  // Remote call on a builtin operator.
  SCRIPT_REMOTE_CALL = MessageTypeFlags::REQUEST_TYPE | 0x04,
  // Remote call on a Python UDF.
  PYTHON_REMOTE_CALL = MessageTypeFlags::REQUEST_TYPE | 0x05,
  // Response to a remote call (UDF, builtin, or script).
  REMOTE_RET = MessageTypeFlags::RESPONSE_TYPE | 0x06,

  // --- Internal RRef messages ---
  // A UserRRef<IValue> fetches its value from the owner.
  SCRIPT_RREF_FETCH_CALL = MessageTypeFlags::REQUEST_TYPE | 0x07,
  // A UserRRef<py::object> fetches its value from the owner.
  PYTHON_RREF_FETCH_CALL = MessageTypeFlags::REQUEST_TYPE | 0x08,
  // An OwnerRRef sends an IValue to a user.
  SCRIPT_RREF_FETCH_RET = MessageTypeFlags::RESPONSE_TYPE | 0x09,
  // An OwnerRRef sends a py::object to a user.
  PYTHON_RREF_FETCH_RET = MessageTypeFlags::RESPONSE_TYPE | 0x0a,
  // A UserRRef tells the owner to deref.
  RREF_USER_DELETE = MessageTypeFlags::REQUEST_TYPE | 0x0b,
  // A child UserRRef tells the owner about itself.
  RREF_FORK_REQUEST = MessageTypeFlags::REQUEST_TYPE | 0x0c,
  // A child UserRRef tells its parent that the owner knows about it.
  RREF_CHILD_ACCEPT = MessageTypeFlags::REQUEST_TYPE | 0x0d,
  // Acknowledgement of internal RRef messages.
  RREF_ACK = MessageTypeFlags::RESPONSE_TYPE | 0x0e,

  // --- Messages carrying autograd information ---
  FORWARD_AUTOGRAD_REQ = MessageTypeFlags::REQUEST_TYPE | 0x0f,
  FORWARD_AUTOGRAD_RESP = MessageTypeFlags::RESPONSE_TYPE | 0x10,

  // --- Gradient propagation on the backward pass ---
  BACKWARD_AUTOGRAD_REQ = MessageTypeFlags::REQUEST_TYPE | 0x11,
  BACKWARD_AUTOGRAD_RESP = MessageTypeFlags::RESPONSE_TYPE | 0x12,

  // --- Telling workers to clean up their autograd context ---
  CLEANUP_AUTOGRAD_CONTEXT_REQ = MessageTypeFlags::REQUEST_TYPE | 0x13,
  CLEANUP_AUTOGRAD_CONTEXT_RESP = MessageTypeFlags::RESPONSE_TYPE | 0x14,

  // --- Telling workers to run requests with profiling enabled ---
  RUN_WITH_PROFILING_REQ = MessageTypeFlags::REQUEST_TYPE | 0x15,
  RUN_WITH_PROFILING_RESP = MessageTypeFlags::RESPONSE_TYPE | 0x16,

  // --- RRef.backward() support ---
  RREF_BACKWARD_REQ = MessageTypeFlags::REQUEST_TYPE | 0x17,
  RREF_BACKWARD_RESP = MessageTypeFlags::RESPONSE_TYPE | 0x18,

  // --- Other internal message types ---
  EXCEPTION = MessageTypeFlags::RESPONSE_TYPE | 0x37,
  UNKNOWN = 0x3c,
};
// A message to be sent/received by an RpcAgent.
//
// A Message object contains 4 fields:
// payload (std::vector<char>): a binary chunk of data.
// tensors (std::vector<torch::Tensor>): all tensors. Tensor data are not
// included in the payload, and it is up to the RpcAgent implementation
// to determine how to serialize them. This design helps when
// communicating very large tensors, where serializing all the data at
// once would lead to an excessively large memory footprint. An
// implementation can instead serialize and send tensors chunk-by-chunk, in
// a streaming fashion.
// type (MessageType): type of the message.
// id (int64_t): message id, used to match requests with responses.
// Other implementations can ignore it if they have their own
// ways to do the matching.
//
// Layers above ``RpcAgent`` only convert ScriptCall, ScriptResp, PythonCall,
// and PythonResp into a Message, and it is up to the RpcAgent
// implementation to determine how to serialize a message.
//
// Instances are neither copyable nor movable (see the deleted special members
// below) and are meant to be shared through c10::intrusive_ptr.
class TORCH_API Message final : public torch::CustomClassHolder {
private:
// Keep these private in order to force users to go through make_intrusive and
// thus prevent creating a Message that's not held by an intrusive_ptr.
Message();
// Takes ownership of the payload and tensors. No id is passed, so id_ keeps
// its in-class default of -1 unless the out-of-line definition sets it.
// NOTE(review): confirm against message.cpp.
Message(
std::vector<char>&& payload,
std::vector<torch::Tensor>&& tensors,
MessageType type);
// Same as above, but also sets the message id used for request/response
// matching.
Message(
std::vector<char>&& payload,
std::vector<torch::Tensor>&& tensors,
MessageType type,
int64_t id);
// Lets c10::intrusive_ptr<Message> reach the private constructors above, which
// is what makes the make_intrusive route possible.
friend c10::intrusive_ptr<Message>;
public:
// Copying and moving are disabled: a Message is shared by reference count,
// never duplicated or relocated.
Message(const Message& other) = delete;
Message(Message&& other) = delete;
Message& operator=(Message const& rhs) = delete;
Message& operator=(Message&& rhs) = delete;
// Destructively retrieves the payload. Rvalue-qualified, so it can only be
// called on an rvalue, e.g. `std::move(*msg).movePayload()`.
std::vector<char>&& movePayload() &&;
// Rvalue-qualified counterpart of movePayload() for the tensors; it hands the
// tensors out as an rvalue reference so the caller can move them out.
std::vector<torch::Tensor>&& moveTensors() &&;
// Non-destructive mutable and const accessors for the payload and tensors.
std::vector<char>& payload();
const std::vector<char>& payload() const;
std::vector<torch::Tensor>& tensors();
const std::vector<torch::Tensor>& tensors() const;
MessageType type() const;
// Message classification helpers, defined out of line.
// NOTE(review): isRequest()/isResponse() presumably test the REQUEST_TYPE /
// RESPONSE_TYPE flag bits of type(); confirm against message.cpp.
bool isRequest() const;
bool isResponse() const;
bool isShutdown() const;
// id is an optional field to match request/response. If an RpcAgent
// implementation is able to do the matching without using this id, it can be
// dropped during message serialization.
int64_t id() const;
void setId(int64_t id);
// Returns weak (non-owning) references to StorageImpls.
// NOTE(review): presumably the storages backing tensors_; confirm against
// message.cpp.
std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> getStorages() const;
private:
std::vector<char> payload_;
std::vector<torch::Tensor> tensors_;
// In-class defaults: UNKNOWN type and id -1 unless a constructor or setId()
// assigns other values.
MessageType type_ = MessageType::UNKNOWN;
int64_t id_ = -1;
};
// createExceptionResponse: builds a response Message of type Exception.
//
// Both overloads take `id`, the message ID of the request whose handling
// produced the error, so the response can be matched to that request.

// Overload taking an already-rendered error string; the passed-in string is
// used as the message's payload.
TORCH_API c10::intrusive_ptr<Message> createExceptionResponse(
    const std::string& exceptionStr,
    int64_t id);

// Overload taking a C++ exception; the exception's string representation is
// used as the message's payload.
TORCH_API c10::intrusive_ptr<Message> createExceptionResponse(
    const std::exception& e,
    int64_t id);
// Pairs `message` with the weak storage references reported by
// Message::getStorages(). The storages are collected first, because `message`
// is moved into the returned tuple and must not be dereferenced afterwards.
inline std::tuple<
    c10::intrusive_ptr<Message>,
    std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>>>
withStorages(c10::intrusive_ptr<Message> message) {
  using WeakStorages = std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>>;
  WeakStorages weakStorages = message->getStorages();
  return std::tuple<c10::intrusive_ptr<Message>, WeakStorages>(
      std::move(message), std::move(weakStorages));
}
// Short name, within torch::distributed::rpc, for c10::ivalue::Future.
using JitFuture = c10::ivalue::Future;
} // namespace rpc
} // namespace distributed
} // namespace torch
|