File: message.h

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (193 lines) | stat: -rw-r--r-- 7,528 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
#pragma once

#include <torch/types.h>
#include <vector>

namespace torch {
namespace distributed {
namespace rpc {

// Categories of well-known RPC failures, so that callers can handle each kind
// specifically.
enum RPCErrorType {
  // The error type could not be parsed.
  UNKNOWN_ERROR = 0,
  // The RPC timed out.
  TIMEOUT = 1,
  // A deliberate failure, such as those injected by FaultyAgent for testing.
  INTENTIONAL_FAILURE = 2
};

// Flags that are bitwise ORed into MessageType values.
// Each flag is a distinct bit at or above 0x100 (0x100, 0x200, 0x400,
// 0x800, ...), which leaves the low byte free for the message type itself.
enum MessageTypeFlags {
  // Set on message types that are requests.
  REQUEST_TYPE = 0x100,
  // Set on message types that are responses.
  RESPONSE_TYPE = 0x200,
};

// Identifies the kind of a Message.
// The type number itself must lie between 0x00 and 0xff; a MessageTypeFlags
// value is ORed on top of it to mark the message as a request or a response.
enum MessageType {
  // ---- dist.rpc on builtin operators ----
  SCRIPT_CALL = 0x00 | MessageTypeFlags::REQUEST_TYPE,
  SCRIPT_RET = 0x01 | MessageTypeFlags::RESPONSE_TYPE,

  // ---- dist.rpc on Python UDF ----
  PYTHON_CALL = 0x02 | MessageTypeFlags::REQUEST_TYPE,
  PYTHON_RET = 0x03 | MessageTypeFlags::RESPONSE_TYPE,

  // ---- dist.remote on builtin operators and Python UDF ----
  // A remote call on a builtin operator.
  SCRIPT_REMOTE_CALL = 0x04 | MessageTypeFlags::REQUEST_TYPE,
  // A remote call on a Python UDF.
  PYTHON_REMOTE_CALL = 0x05 | MessageTypeFlags::REQUEST_TYPE,
  // Response for remote calls for UDF, builtin, or script.
  REMOTE_RET = 0x06 | MessageTypeFlags::RESPONSE_TYPE,

  // ---- RRef related internal messages ----
  // A UserRRef<IValue> fetches value from owner.
  SCRIPT_RREF_FETCH_CALL = 0x07 | MessageTypeFlags::REQUEST_TYPE,
  // A UserRRef<py::object> fetches value from owner.
  PYTHON_RREF_FETCH_CALL = 0x08 | MessageTypeFlags::REQUEST_TYPE,
  // An OwnerRRef sends ivalue to user.
  SCRIPT_RREF_FETCH_RET = 0x09 | MessageTypeFlags::RESPONSE_TYPE,
  // An OwnerRRef sends py::object to user.
  PYTHON_RREF_FETCH_RET = 0x0a | MessageTypeFlags::RESPONSE_TYPE,
  // A UserRRef tells the owner to deref.
  RREF_USER_DELETE = 0x0b | MessageTypeFlags::REQUEST_TYPE,
  // A child UserRRef tells the owner about itself.
  RREF_FORK_REQUEST = 0x0c | MessageTypeFlags::REQUEST_TYPE,
  // A child UserRRef tells parent that owner knows it.
  RREF_CHILD_ACCEPT = 0x0d | MessageTypeFlags::REQUEST_TYPE,
  // ACK to internal RRef messages.
  RREF_ACK = 0x0e | MessageTypeFlags::RESPONSE_TYPE,

  // ---- Messages with autograd info ----
  FORWARD_AUTOGRAD_REQ = 0x0f | MessageTypeFlags::REQUEST_TYPE,
  FORWARD_AUTOGRAD_RESP = 0x10 | MessageTypeFlags::RESPONSE_TYPE,

  // ---- Messages to propagate gradients on the backward pass ----
  BACKWARD_AUTOGRAD_REQ = 0x11 | MessageTypeFlags::REQUEST_TYPE,
  BACKWARD_AUTOGRAD_RESP = 0x12 | MessageTypeFlags::RESPONSE_TYPE,

  // ---- Messages to tell workers to clean up their autograd context ----
  CLEANUP_AUTOGRAD_CONTEXT_REQ = 0x13 | MessageTypeFlags::REQUEST_TYPE,
  CLEANUP_AUTOGRAD_CONTEXT_RESP = 0x14 | MessageTypeFlags::RESPONSE_TYPE,

  // ---- Messages that tell workers to run requests with profiling enabled ----
  RUN_WITH_PROFILING_REQ = 0x15 | MessageTypeFlags::REQUEST_TYPE,
  RUN_WITH_PROFILING_RESP = 0x16 | MessageTypeFlags::RESPONSE_TYPE,

  // ---- Messages to support RRef.backward() ----
  RREF_BACKWARD_REQ = 0x17 | MessageTypeFlags::REQUEST_TYPE,
  RREF_BACKWARD_RESP = 0x18 | MessageTypeFlags::RESPONSE_TYPE,

  // ---- Other internal message types ----
  EXCEPTION = 0x37 | MessageTypeFlags::RESPONSE_TYPE,
  // Carries neither the request nor the response flag.
  UNKNOWN = 0x3c
};

// A message to be sent/received by an RpcAgent.
//
// A Message object contains 4 fields:
//    payload (std::vector<char>): a binary chunk of data.
//    tensors (std::vector<torch::Tensor>): all tensors. Tensor data are not
//        included in the payload, and it is up to the RpcAgent implementation
//        to determine how to serialize them. This design is helpful for
//        communicating very large tensors, where serializing all the data at
//        once leads to an excessively large memory footprint. An
//        implementation can then serialize and send tensors chunk-by-chunk,
//        in a streaming fashion.
//    type (MessageType): type of the message.
//    id (int64_t): message id, used to match a request with its response.
//        Other implementations can ignore it if they have their own way of
//        doing the matching.
//
// Layers above ``RpcAgent`` only convert ScriptCall, ScriptResp, PythonCall,
// and PythonResp into a Message, and it is up to the RpcAgent
// implementation to determine how to serialize a message.
//
// The class is final, and is neither copyable nor movable (all four copy/move
// operations are deleted below).
class TORCH_API Message final : public torch::CustomClassHolder {
 private:
  // Keep these private in order to force users to go through make_intrusive and
  // thus prevent creating a Message that's not held by an intrusive_ptr.
  //
  // Default constructor. Defined outside this header; presumably leaves the
  // members at their in-class defaults (see the data members at the bottom).
  Message();

  // Constructs a message from a payload, tensors and type. Both vectors are
  // taken by rvalue reference. No id is passed to this overload.
  Message(
      std::vector<char>&& payload,
      std::vector<torch::Tensor>&& tensors,
      MessageType type);

  // Same as above, with an explicit message id.
  Message(
      std::vector<char>&& payload,
      std::vector<torch::Tensor>&& tensors,
      MessageType type,
      int64_t id);

  // Grants c10::intrusive_ptr<Message> access to the private constructors.
  friend c10::intrusive_ptr<Message>;

 public:
  // Messages can be neither copied nor moved.
  Message(const Message& other) = delete;
  Message(Message&& other) = delete;
  Message& operator=(Message const& rhs) = delete;
  Message& operator=(Message&& rhs) = delete;

  // Destructively retrieves the payload.
  // Rvalue-qualified: callable only on an rvalue Message, e.g.
  // std::move(msg).movePayload().
  std::vector<char>&& movePayload() &&;
  // Rvalue-qualified counterpart of movePayload() for the tensors.
  std::vector<torch::Tensor>&& moveTensors() &&;

  // Mutable and read-only accessors for the payload and the tensors.
  std::vector<char>& payload();
  const std::vector<char>& payload() const;
  std::vector<torch::Tensor>& tensors();
  const std::vector<torch::Tensor>& tensors() const;
  // Returns the message type.
  MessageType type() const;

  // Defined outside this header. Presumably these test the
  // MessageTypeFlags::REQUEST_TYPE / RESPONSE_TYPE bits of the type -- verify
  // against the implementation.
  bool isRequest() const;
  bool isResponse() const;
  // NOTE(review): MessageType declares no shutdown-related value in this
  // header; confirm what this checks and that it is still defined.
  bool isShutdown() const;

  // id is an optional field to match request/response. If an RpcAgent
  // implementation is able to do the matching without using this id, it can be
  // dropped during message serialization.
  int64_t id() const;
  void setId(int64_t id);

  // Returns weak references to c10::StorageImpl objects. Presumably these are
  // the storages backing the tensors of this message -- verify against the
  // implementation.
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> getStorages() const;

 private:
  // Binary payload; tensor data is not part of it (see the class comment).
  std::vector<char> payload_;
  // Tensors carried alongside the payload.
  std::vector<torch::Tensor> tensors_;
  // In-class default is MessageType::UNKNOWN.
  MessageType type_ = MessageType::UNKNOWN;
  // In-class default is -1; presumably meaning "no id assigned" -- TODO confirm.
  int64_t id_ = -1;
};

// Create a response Message of type Exception (presumably
// MessageType::EXCEPTION -- defined outside this header).
// The exception string representation will be used as the message's payload.
//
// e:  the exception to report back to the caller.
// id: a message ID corresponding to the request that resulted in this
//     response; it can be used for matching requests/responses.
//
// Returns the newly created Message, held by an intrusive_ptr.
TORCH_API c10::intrusive_ptr<Message> createExceptionResponse(
    const std::exception& e,
    int64_t id);

// Create a response Message of type Exception (presumably
// MessageType::EXCEPTION -- defined outside this header).
// The passed in string representation will be used as the message's payload.
//
// exceptionStr: textual description of the failure.
// id:           a message ID corresponding to the request that resulted in
//               this response; it can be used for matching
//               requests/responses.
//
// Returns the newly created Message, held by an intrusive_ptr.
TORCH_API c10::intrusive_ptr<Message> createExceptionResponse(
    const std::string& exceptionStr,
    int64_t id);

// Bundles a message together with the result of its getStorages() call.
// Returns a tuple of (message, storages); ownership of the message pointer is
// transferred into the tuple.
inline std::tuple<
    c10::intrusive_ptr<Message>,
    std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>>>
withStorages(c10::intrusive_ptr<Message> message) {
  using StorageList = std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>>;
  using MessageAndStorages = std::tuple<c10::intrusive_ptr<Message>, StorageList>;

  // The storages must be read while `message` is still valid, i.e. before it
  // is moved into the result below.
  StorageList storageList = message->getStorages();
  return MessageAndStorages(std::move(message), std::move(storageList));
}

// Shorthand for c10::ivalue::Future inside the rpc namespace.
using JitFuture = c10::ivalue::Future;

} // namespace rpc
} // namespace distributed
} // namespace torch