File: torch/csrc/distributed/rpc/message.cpp

#include <torch/csrc/distributed/rpc/message.h>

namespace torch {
namespace distributed {
namespace rpc {

Message::Message() = default;

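// Construct a message, taking ownership of the payload bytes and tensors by
// move; the id is left at its default until setId() assigns one (see the
// overload below for constructing with an explicit id).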
Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type)
    : payload_(std::move(payload)), tensors_(std::move(tensors)), type_(type) {}

Message::Message(
    std::vector<char>&& payload,
    std::vector<torch::Tensor>&& tensors,
    MessageType type,
    int64_t id)
    : payload_(std::move(payload)),
      tensors_(std::move(tensors)),
      type_(type),
      id_(id) {}

Message::Message(const Message& other) = default;

Message::Message(Message&& other) noexcept = default;

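// Copy assignment uses the copy-and-swap idiom: copy rhs's members, build a
// temporary Message from them, and swap it into *this. The & ref-qualifier
// prevents assigning to a temporary.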
Message& Message::operator=(Message const& rhs) & {
  auto payload = rhs.payload_;
  auto tensors = rhs.tensors_;
  Message(std::move(payload), std::move(tensors), rhs.type_, rhs.id_)
      .swap(*this);
  return *this;
}

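// Move assignment reuses the same pattern, stealing rhs's members instead of
// copying them.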
Message& Message::operator=(Message&& rhs) & {
  Message(std::move(rhs.payload_), std::move(rhs.tensors_), rhs.type_, rhs.id_)
      .swap(*this);
  return *this;
}

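// Member-wise swap; the backbone of both assignment operators above.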
void Message::swap(Message& rhs) noexcept {
  std::swap(payload_, rhs.payload_);
  std::swap(tensors_, rhs.tensors_);
  std::swap(type_, rhs.type_);
  std::swap(id_, rhs.id_);
}

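// The && ref-qualifier restricts the move-out accessors below to rvalue
// Messages, e.g. std::move(msg).movePayload(), so a live Message cannot be
// gutted by accident. moveTensors() follows the same pattern.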
std::vector<char>&& Message::movePayload() && {
  return std::move(payload_);
}

std::vector<char>& Message::payload() {
  return payload_;
}

const std::vector<char>& Message::payload() const {
  return payload_;
}

std::vector<torch::Tensor>&& Message::moveTensors() && {
  return std::move(tensors_);
}

std::vector<torch::Tensor>& Message::tensors() {
  return tensors_;
}

const std::vector<torch::Tensor>& Message::tensors() const {
  return tensors_;
}

MessageType Message::type() const {
  return type_;
}

bool Message::isRequest() const {
  return MessageType::SCRIPT_CALL == type_ || // dist.rpc on builtin ops
      MessageType::PYTHON_CALL == type_ || // dist.rpc on Python UDFs
      MessageType::SCRIPT_REMOTE_CALL == type_ || // dist.remote on builtin ops
      MessageType::PYTHON_REMOTE_CALL == type_ || // dist.remote on Python UDFs
      // RRef related internal messages
      MessageType::SCRIPT_RREF_FETCH_CALL == type_ ||
      MessageType::PYTHON_RREF_FETCH_CALL == type_ ||
      MessageType::RREF_USER_DELETE == type_ ||
      MessageType::RREF_CHILD_ACCEPT == type_ ||
      MessageType::RREF_FORK_REQUEST == type_ ||
      // Autograd message
      MessageType::BACKWARD_AUTOGRAD_REQ == type_ ||
      MessageType::FORWARD_AUTOGRAD_REQ == type_ ||
      // Cleanup Autograd context request
      MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ == type_ ||
      // Run with profiling request
      MessageType::RUN_WITH_PROFILING_REQ == type_;
}

bool Message::isResponse() const {
  return MessageType::SCRIPT_RET == type_ || // ret of dist.rpc on builtin ops
      MessageType::PYTHON_RET == type_ || // ret of dist.rpc on Python UDFs
      MessageType::REMOTE_RET == type_ || // ret of dist.remote
      MessageType::SCRIPT_RREF_FETCH_RET == type_ || // ret on RRef::toHere()
      MessageType::PYTHON_RREF_FETCH_RET == type_ || // ret on RRef::toHere()
      MessageType::EXCEPTION == type_ || // propagate back exceptions
      MessageType::RREF_ACK == type_ || // ret of other types
      // Autograd response
      MessageType::BACKWARD_AUTOGRAD_RESP == type_ ||
      MessageType::FORWARD_AUTOGRAD_RESP == type_ ||
      // Cleanup autograd context response
      MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP == type_ ||
      // Run with profiling response
      MessageType::RUN_WITH_PROFILING_RESP == type_;
}

int64_t Message::id() const {
  return id_;
}

void Message::setId(int64_t id) {
  id_ = id;
}

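// Wrap an exception in an EXCEPTION-typed response whose payload carries the
// error text, so the error can be propagated back to the caller.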
Message createExceptionResponse(const std::exception& e, int64_t id) {
  return createExceptionResponse(e.what(), id);
}

Message createExceptionResponse(const std::string& exceptionStr, int64_t id) {
  std::vector<char> payload(exceptionStr.begin(), exceptionStr.end());
  return Message(
      std::move(payload),
      std::vector<torch::Tensor>(),
      MessageType::EXCEPTION,
      id);
}

} // namespace rpc
} // namespace distributed
} // namespace torch
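
// A minimal usage sketch (illustrative only; the id value 1 is arbitrary):
//
//   using namespace torch::distributed::rpc;
//
//   std::string s = "hello";
//   Message m(
//       std::vector<char>(s.begin(), s.end()),
//       std::vector<torch::Tensor>(),
//       MessageType::SCRIPT_CALL,
//       /*id=*/1);
//   assert(m.isRequest() && !m.isResponse());
//   // Only an rvalue Message may be drained of its payload.
//   std::vector<char> bytes = std::move(m).movePayload();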