File: rpc_with_autograd.cpp

package info (click to toggle)
pytorch 1.13.1%2Bdfsg-4
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 139,252 kB
  • sloc: cpp: 1,100,274; python: 706,454; ansic: 83,052; asm: 7,618; java: 3,273; sh: 2,841; javascript: 612; makefile: 323; xml: 269; ruby: 185; yacc: 144; objc: 68; lex: 44
file content (180 lines) | stat: -rw-r--r-- 5,953 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
#include <c10/util/C++17.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/utils.h>
#include <torch/csrc/jit/serialization/pickle.h>
#include <torch/csrc/utils/byte_order.h>

namespace torch {
namespace distributed {
namespace autograd {

using rpc::Message;
using rpc::MessageType;
using rpc::RpcCommandBase;
using rpc::worker_id_t;

// Constructs an RpcWithAutograd around an already-serialized wrapped message.
// The wrapped message's tensors and type are cached in tensors_ and
// wrappedMessageType_; wrappedRpc_ is left unset by this constructor.
//
// @param fromWorkerId     id of the worker sending this RPC.
// @param messageType      must be FORWARD_AUTOGRAD_REQ or FORWARD_AUTOGRAD_RESP.
// @param autogradMetadata autograd context id / message id for this RPC.
// @param wrappedMessage   serialized inner message; must be non-null.
// @param deviceMap        device mapping to serialize alongside the message.
RpcWithAutograd::RpcWithAutograd(
    worker_id_t fromWorkerId,
    MessageType messageType,
    const AutogradMetadata& autogradMetadata,
    c10::intrusive_ptr<rpc::Message> wrappedMessage,
    rpc::DeviceMap deviceMap)
    : fromWorkerId_(fromWorkerId),
      messageType_(messageType),
      autogradMetadata_(autogradMetadata),
      wrappedMessage_(std::move(wrappedMessage)),
      deviceMap_(std::move(deviceMap)) {
  TORCH_INTERNAL_ASSERT(
      messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
      messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
  // wrappedMessage_ is dereferenced immediately below; fail with a clear
  // internal assert (mirroring the wrappedRpc_ check in the other
  // constructor) instead of a null-pointer dereference.
  TORCH_INTERNAL_ASSERT(
      wrappedMessage_.defined(), "wrappedMessage cannot be null!");
  tensors_ = wrappedMessage_->tensors();
  wrappedMessageType_ = wrappedMessage_->type();
}

RpcWithAutograd::RpcWithAutograd(
    worker_id_t fromWorkerId,
    MessageType messageType,
    const AutogradMetadata& autogradMetadata,
    std::unique_ptr<RpcCommandBase> wrappedRpc,
    MessageType wrappedMessageType,
    std::vector<torch::Tensor> tensors,
    rpc::DeviceMap deviceMap)
    : fromWorkerId_(fromWorkerId),
      messageType_(messageType),
      autogradMetadata_(autogradMetadata),
      wrappedRpc_(std::move(wrappedRpc)),
      wrappedMessageType_(wrappedMessageType),
      tensors_(std::move(tensors)),
      deviceMap_(std::move(deviceMap)) {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  TORCH_INTERNAL_ASSERT(
      messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
      messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
}

// Serializes this RPC into a message of type messageType_.
//
// The payload of wrappedMessage_ is moved out and reused. A pickled tuple of
// (wrapped message type, autograd context id, autograd message id,
// sender worker id, device map) is wrapped into it via
// rpc::writeWrappedPayload; fromMessage() reads the tuple back in the same
// order. tensors_ is moved into the returned message, hence the rvalue
// qualifier.
//
// Requires wrappedMessage_ to be set, i.e. this object was built with the
// constructor taking a serialized Message (fromMessage() uses the other
// constructor, which leaves wrappedMessage_ unset).
c10::intrusive_ptr<Message> RpcWithAutograd::toMessageImpl() && {
  // Fail with a clear internal assert instead of a null dereference when
  // called on an instance that only carries a wrappedRpc_.
  TORCH_INTERNAL_ASSERT(
      wrappedMessage_.defined(), "wrappedMessage cannot be null!");

  auto messageId = wrappedMessage_->id();
  auto wrappedMessageType = wrappedMessage_->type();

  auto payload = std::move(*wrappedMessage_).movePayload();
  TORCH_INTERNAL_ASSERT(!payload.empty());

  // Convert deviceMap_ to c10::Dict for serialization.
  c10::Dict<std::string, std::string> deviceMapDict;
  for (const auto& mapEntry : deviceMap_) {
    deviceMapDict.insert(mapEntry.first.str(), mapEntry.second.str());
  }

  std::vector<at::IValue> ivalues{
      wrappedMessageType,
      autogradMetadata_.autogradContextId,
      autogradMetadata_.autogradMessageId,
      fromWorkerId_,
      deviceMapDict};

  // Now pickle using JIT pickler.
  std::vector<torch::Tensor> tensorTable;
  std::vector<char> additionalPayload =
      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);

  // The metadata tuple contains only scalars and strings, so the pickler
  // must not have extracted any tensors.
  TORCH_INTERNAL_ASSERT(tensorTable.empty());

  // This wraps additionalPayload into payload and takes care of resizing and
  // encoding.
  rpc::writeWrappedPayload(payload, additionalPayload);

  return c10::make_intrusive<Message>(
      std::move(payload), std::move(tensors_), messageType_, messageId);
}

// Reconstructs an RpcWithAutograd from a FORWARD_AUTOGRAD_REQ/RESP message.
//
// rpc::readWrappedPayload extracts the metadata tuple written by
// toMessageImpl(): (wrapped message type, autograd context id, autograd
// message id, sender worker id, device map). The remaining payload, together
// with the message's tensors, is re-wrapped as a Message of the wrapped type
// and deserialized into the inner RPC. Requests go through
// deserializeRequest; responses go through deserializeResponse.
//
// @param message a message of type FORWARD_AUTOGRAD_REQ or
//                FORWARD_AUTOGRAD_RESP.
// @return the decoded RpcWithAutograd, holding the deserialized inner RPC.
std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
    const Message& message) {
  MessageType originalMessageType = message.type();
  TORCH_INTERNAL_ASSERT(
      MessageType::FORWARD_AUTOGRAD_REQ == originalMessageType ||
      MessageType::FORWARD_AUTOGRAD_RESP == originalMessageType);

  std::vector<torch::Tensor> tensors = message.tensors();
  int64_t messageId = message.id();
  // Decode message type, autograd context id, autograd message id, the id of
  // the worker from which we received this message, and the device map.
  auto payload = message.payload();
  auto tupleElements = rpc::readWrappedPayload(payload, message);

  // Gather all the fields. Report the actual size to ease debugging of
  // malformed or version-mismatched messages.
  TORCH_INTERNAL_ASSERT(
      tupleElements.size() == 5,
      "Expected 5 elements in the RpcWithAutograd payload, but got ",
      tupleElements.size());
  MessageType wrappedMessageType =
      static_cast<MessageType>(tupleElements[0].toInt());
  AutogradMetadata autogradMetadata(
      tupleElements[1].toInt(), tupleElements[2].toInt());
  worker_id_t workerId = tupleElements[3].toInt();
  auto c10DeviceMap =
      tupleElements[4].to<c10::Dict<std::string, std::string>>();

  // Convert to regular map.
  rpc::DeviceMap deviceMap;
  for (const auto& mapEntry : c10DeviceMap) {
    deviceMap.insert({mapEntry.key(), mapEntry.value()});
  }

  // Create new message type and build wrapped RPC.
  auto wrappedMessage = c10::make_intrusive<Message>(
      std::move(payload), std::move(tensors), wrappedMessageType, messageId);

  std::unique_ptr<RpcCommandBase> wrappedRpc;
  if (originalMessageType == MessageType::FORWARD_AUTOGRAD_REQ) {
    wrappedRpc = deserializeRequest(*wrappedMessage);
  } else {
    wrappedRpc = deserializeResponse(*wrappedMessage, wrappedMessageType);
  }

  // wrappedMessage and deviceMap are locals that die on return, so hand
  // their contents over instead of copying them.
  return std::make_unique<RpcWithAutograd>(
      workerId,
      originalMessageType,
      autogradMetadata,
      std::move(wrappedRpc),
      wrappedMessageType,
      std::move(wrappedMessage->tensors()),
      std::move(deviceMap));
}

// Mutable access to the tensors carried by this RPC.
std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
  auto& carriedTensors = tensors_;
  return carriedTensors;
}

// Read-only access to the autograd context id / message id of this RPC.
const AutogradMetadata& RpcWithAutograd::autogradMetadata() const {
  return this->autogradMetadata_;
}

// Returns the inner RPC command. Asserts that it is present: it is not set by
// the constructor that takes a serialized Message, and it is released by
// moveWrappedRpc().
RpcCommandBase& RpcWithAutograd::wrappedRpc() {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  RpcCommandBase* const command = wrappedRpc_.get();
  return *command;
}

void RpcWithAutograd::setWrappedRpc(
    std::unique_ptr<RpcCommandBase> wrappedRpc) {
  wrappedRpc_ = std::move(wrappedRpc);
}

// Transfers ownership of the inner RPC command to the caller, leaving this
// object without one. Asserts that a command is present.
std::unique_ptr<RpcCommandBase> RpcWithAutograd::moveWrappedRpc() && {
  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
  return std::unique_ptr<RpcCommandBase>(wrappedRpc_.release());
}

// Message type of the inner RPC, as cached at construction.
MessageType RpcWithAutograd::wrappedMessageType() const {
  return this->wrappedMessageType_;
}

// Id of the worker this RPC was sent from.
rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
  return this->fromWorkerId_;
}

// Read-only view of the device mapping associated with this RPC.
const rpc::DeviceMap& RpcWithAutograd::deviceMap() {
  return this->deviceMap_;
}

} // namespace autograd
} // namespace distributed
} // namespace torch