File: message.h

#pragma once

#include <torch/csrc/utils/future.h>
#include <torch/types.h>
#include <vector>

namespace torch {
namespace distributed {
namespace rpc {

// An enum denoting common RPC errors to allow specific error handling for them.
enum RPCErrorType {
  UNKNOWN_ERROR = 0, /* Indicates that the error type could not be parsed */
  TIMEOUT = 1, /* Indicates that the RPC has timed out */
  INTENTIONAL_FAILURE = 2 /* Deliberate failure, such as those injected by
                             FaultyProcessGroupAgent for testing */
};
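
// Illustrative sketch (not part of the original header), assuming the caller
// has already mapped an error message to an RPCErrorType: deciding whether the
// failed RPC is worth retrying. The helper name and retry policy are
// hypothetical and only demonstrate branching on the enum.
inline bool isRetryableRpcError(RPCErrorType errorType) {
  switch (errorType) {
    case RPCErrorType::TIMEOUT:
      // Timeouts are transient and are the typical retry candidates.
      return true;
    case RPCErrorType::INTENTIONAL_FAILURE:
      // Injected by FaultyProcessGroupAgent in tests; do not mask it.
      return false;
    case RPCErrorType::UNKNOWN_ERROR:
    default:
      return false;
  }
}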

enum MessageType {
  // messages for dist.rpc on builtin operators
  SCRIPT_CALL = 0,
  SCRIPT_RET = 1,

  // messages for dist.rpc on Python UDF
  PYTHON_CALL = 2,
  PYTHON_RET = 3,

  // messages for dist.remote on builtin operators and Python UDF
  SCRIPT_REMOTE_CALL = 4, // A remote call on a builtin operator
  PYTHON_REMOTE_CALL = 5, // A remote call on a Python UDF
  REMOTE_RET = 6, // Response for remote calls for UDF, builtin, or script

  // RRef related internal messages
  SCRIPT_RREF_FETCH_CALL = 7, // A UserRRef<IValue> fetches value from owner
  PYTHON_RREF_FETCH_CALL = 8, // A UserRRef<py::object> fetches value from owner
  SCRIPT_RREF_FETCH_RET = 9, // An OwnerRRef sends ivalue to user
  PYTHON_RREF_FETCH_RET = 10, // An OwnerRRef sends py::object to user
  RREF_USER_DELETE = 11, // A UserRRef tells the owner to deref
  RREF_FORK_REQUEST = 12, // A child UserRRef tells the owner about itself
  RREF_CHILD_ACCEPT = 13, // A child UserRRef tells parent that owner knows it
  RREF_ACK = 14, // ACK to internal RRef messages

  // Messages with autograd info
  FORWARD_AUTOGRAD_REQ = 15,
  FORWARD_AUTOGRAD_RESP = 16,

  // Messages to propagate gradients on the backward pass.
  BACKWARD_AUTOGRAD_REQ = 17,
  BACKWARD_AUTOGRAD_RESP = 18,

  // Messages to tell workers to clean up their autograd context.
  CLEANUP_AUTOGRAD_CONTEXT_REQ = 19,
  CLEANUP_AUTOGRAD_CONTEXT_RESP = 20,

  // Messages that tell workers to run requests with profiling enabled.
  RUN_WITH_PROFILING_REQ = 21,
  RUN_WITH_PROFILING_RESP = 22,

  // Other internal message types
  EXCEPTION = 55,
  UNKNOWN = 60
};

// A message to be sent/received by an RpcAgent.
//
// A Message object contains 4 fields:
//    payload (std::vector<char>): a binary chunk of data.
//    tensors (std::vector<torch::Tensor>): all tensors. Tensor data are not
//        included in the payload, and it is up to the RpcAgent implementation
//        to determine how to serialize them. This design is helpful for
//        communicating very large tensors, where serializing all the data at
//        once would lead to an excessively large memory footprint. An
//        implementation can instead serialize and send tensors chunk-by-chunk,
//        in a streaming fashion.
//    type (MessageType): type of the message.
//    id (int64_t): message id, used by ProcessGroupAgent to match a request
//                  with its response. Other implementations can ignore it if
//                  they have their own way of doing the matching.
//
// Layers above ``RpcAgent`` only convert ScriptCall, ScriptResp, PythonCall,
// and PythonResp into a Message; it is up to the RpcAgent implementation to
// determine how to serialize a message.
class TORCH_API Message final {
 public:
  Message();

  Message(
      std::vector<char>&& payload,
      std::vector<torch::Tensor>&& tensors,
      MessageType type);

  Message(
      std::vector<char>&& payload,
      std::vector<torch::Tensor>&& tensors,
      MessageType type,
      int64_t id);

  Message(const Message& other);
  Message(Message&& other) noexcept;
  Message& operator=(Message const& rhs) &;
  Message& operator=(Message&& rhs) &;
  void swap(Message& rhs) noexcept;

  // Destructively retrieve the payload or tensors. Being rvalue-qualified,
  // these can only be called on a Message rvalue (e.g. std::move(msg)).
  std::vector<char>&& movePayload() &&;
  std::vector<torch::Tensor>&& moveTensors() &&;

  std::vector<char>& payload();
  const std::vector<char>& payload() const;
  std::vector<torch::Tensor>& tensors();
  const std::vector<torch::Tensor>& tensors() const;
  MessageType type() const;

  bool isRequest() const;
  bool isResponse() const;
  bool isShutdown() const;

  // id is an optional field to match request/response. If an RpcAgent
  // implementation is able to do the matching without using this id, it can be
  // dropped during message serialization.
  int64_t id() const;
  void setId(int64_t id);

 private:
  std::vector<char> payload_;
  std::vector<torch::Tensor> tensors_;
  MessageType type_ = MessageType::UNKNOWN;
  int64_t id_ = -1;
};
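
// Illustrative sketch (not part of the original header): building a response
// Message from a request by reusing its payload, tensors, and id. The helper
// name `makeEchoResponse` and the choice of SCRIPT_RET are hypothetical; the
// point is to show the constructor and the destructive move accessors.
inline Message makeEchoResponse(Message request) {
  // Capture the id first so the sender's RpcAgent can match this response to
  // the original request.
  const int64_t id = request.id();
  std::vector<char> payload = std::move(request).movePayload();
  std::vector<torch::Tensor> tensors = std::move(request).moveTensors();
  return Message(
      std::move(payload), std::move(tensors), MessageType::SCRIPT_RET, id);
}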

// Create a response Message of type Exception.
// The exception string representation will be used as the message's payload.
// A message ID corresponding to the request that resulted in this response can
// be provided for matching requests/responses.
TORCH_API Message createExceptionResponse(const std::exception& e, int64_t id);

// Create a response Message of type Exception.
// The passed in string representation will be used as the message's payload.
// A message ID corresponding to the request that resulted in this response can
// be provided for matching requests/responses.
TORCH_API Message
createExceptionResponse(const std::string& exceptionStr, int64_t id);
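
// Illustrative sketch (not part of the original header): converting a caught
// exception into an EXCEPTION response that carries the failing request's id,
// so the caller's RpcAgent can match it. The helper name is hypothetical.
inline Message handleWithExceptionResponse(
    const Message& request,
    const std::exception& e) {
  return createExceptionResponse(e, request.id());
}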

// FutureMessage is an internal type used in the communication layer. All
// user-facing surface APIs should use JitFuture instead.
using FutureMessage = torch::utils::Future<Message>;
using JitFuture = c10::ivalue::Future;

} // namespace rpc
} // namespace distributed
} // namespace torch