File: faulty_process_group_agent.cpp

package info (click to toggle)
pytorch 1.7.1-7
  • links: PTS, VCS
  • area: main
  • in suites: bullseye
  • size: 80,340 kB
  • sloc: cpp: 670,830; python: 343,991; ansic: 67,845; asm: 5,503; sh: 2,924; java: 2,888; xml: 266; makefile: 244; ruby: 148; yacc: 144; objc: 51; lex: 44
file content (149 lines) | stat: -rw-r--r-- 6,005 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#include <torch/csrc/distributed/rpc/testing/faulty_process_group_agent.h>
#include <torch/csrc/distributed/rpc/request_callback_impl.h>
#include <torch/csrc/distributed/rpc/utils.h>

namespace torch {
namespace distributed {
namespace rpc {

std::string fromVec(const std::vector<char>& vec) {
  // Byte-for-byte copy of the buffer into a string; embedded '\0' characters
  // are preserved because the copy is driven by the iterator range, not by a
  // terminator.
  std::string result;
  result.assign(vec.cbegin(), vec.cend());
  return result;
}

// Builds a ProcessGroupAgent (handling requests with a RequestCallbackImpl)
// and stores the fault-injection configuration supplied by the tests:
//   failNumSends        - number of send attempts, per distinct message
//                         payload, that send() fails before letting one through.
//   messagesToFail      - names of the message types that send() may fail,
//                         resolved to MessageType by parseMessagesToFailInput().
//   messageTypesToDelay - message type name -> delay applied before sending,
//                         resolved (and checked non-negative) by
//                         parseMessagesToDelay(). The delay is multiplied by
//                         kSecToMsConversion to get milliseconds before
//                         sleeping, i.e. it is expressed in seconds.
// The parse helpers are invoked from the member-initializer list; they only
// read their arguments and messageStringToType()'s function-local static
// table, so they do not depend on other members already being initialized.
FaultyProcessGroupAgent::FaultyProcessGroupAgent(
    std::string workerName,
    std::shared_ptr<c10d::ProcessGroup> pg,
    int numSendRecvThreads,
    std::chrono::milliseconds rpcTimeout,
    const std::vector<std::string>& messagesToFail,
    const std::unordered_map<std::string, float>& messageTypesToDelay,
    int failNumSends)
    : ProcessGroupAgent(
          std::move(workerName),
          std::move(pg),
          numSendRecvThreads,
          rpcTimeout,
          std::make_unique<RequestCallbackImpl>()),
      failNumSends_(failNumSends),
      messageTypesToFail_(parseMessagesToFailInput(messagesToFail)),
      messageTypesToDelay_(parseMessagesToDelay(messageTypesToDelay)) {}

std::vector<MessageType> FaultyProcessGroupAgent::parseMessagesToFailInput(
    const std::vector<std::string>& messagesToFail) const {
  // Since we can only pass strings corresponding to the Message Types from the
  // python tests, we must parse the list of strings and resolve the actual
  // types. We will then check this list of types in the send function to
  // determine whether we should fail or not.
  std::vector<MessageType> messageTypesToFail;
  messageTypesToFail.reserve(messagesToFail.size());
  for (const auto& msgString : messagesToFail) {
    messageTypesToFail.push_back(messageStringToType(msgString));
  }
  return messageTypesToFail;
}

std::unordered_map<MessageType, float, std::hash<int>> FaultyProcessGroupAgent::
    parseMessagesToDelay(const std::unordered_map<std::string, float>&
                             messageTypesToDelay) const {
  std::unordered_map<MessageType, float, std::hash<int>> delayMessages;
  for (const auto& messagePair : messageTypesToDelay) {
    float delay = messagePair.second;
    TORCH_CHECK(
        delay >= 0,
        "Delays passed to FaultyProcessGroupAgent must be non-negative.")
    delayMessages.insert({messageStringToType(messagePair.first), delay});
  }
  return delayMessages;
}

std::shared_ptr<FutureMessage> FaultyProcessGroupAgent::send(
    const WorkerInfo& to,
    Message&& message,
    const float rpcTimeoutSeconds) {
  // Only message types listed by the test case are candidates for failure;
  // every other message is forwarded to ProcessGroupAgent::send untouched.
  if (!shouldFailMessage(message.type())) {
    return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
  }

  // Failed attempts are counted per distinct payload (stored as a string key
  // in failMessageCountMap_). The first failNumSends_ sends of a given
  // payload are failed; subsequent sends of it go through normally.
  const auto key = fromVec(message.payload());
  bool failThisSend = false;
  {
    std::lock_guard<std::mutex> guard(failMapMutex_);
    // operator[] inserts a value-initialized counter for an unseen key (zero,
    // assuming the integral count type implied by the `<`/`++` usage), so a
    // single lookup both creates and reads/updates the entry.
    auto& attempts = failMessageCountMap_[key];
    if (attempts < failNumSends_) {
      ++attempts;
      failThisSend = true;
    }
  }

  if (!failThisSend) {
    return ProcessGroupAgent::send(to, std::move(message), rpcTimeoutSeconds);
  }
  // Hand back a future that already carries the intentional-failure error;
  // this is built outside the lock.
  auto fm = std::make_shared<FutureMessage>();
  fm->setError(makeRPCError(
      c10::str("Send attempt failed intentionally for ", key),
      RPCErrorType::INTENTIONAL_FAILURE));
  return fm;
}

void FaultyProcessGroupAgent::enqueueSend(SendWork work) {
  // Stall the calling thread for the delay configured for this message type
  // (zero means "no delay") before handing the work to the base agent.
  const float delaySeconds = getDelayForMessage(work.message_.type());
  if (delaySeconds == 0) {
    ProcessGroupAgent::enqueueSend(std::move(work));
    return;
  }
  const auto delayMs = static_cast<int>(delaySeconds * kSecToMsConversion);
  std::this_thread::sleep_for(std::chrono::milliseconds(delayMs));
  ProcessGroupAgent::enqueueSend(std::move(work));
}

void FaultyProcessGroupAgent::sendToSelf(Message&& message) {
  // Self-addressed messages receive the same artificial per-type delay as
  // remote sends; a zero delay forwards them immediately.
  const float delaySeconds = getDelayForMessage(message.type());
  if (delaySeconds == 0) {
    ProcessGroupAgent::sendToSelf(std::move(message));
    return;
  }
  const auto delayMs = static_cast<int>(delaySeconds * kSecToMsConversion);
  std::this_thread::sleep_for(std::chrono::milliseconds(delayMs));
  ProcessGroupAgent::sendToSelf(std::move(message));
}

bool FaultyProcessGroupAgent::shouldFailMessage(MessageType type) const {
  // A message is eligible for intentional failure only if its type appears
  // in messageTypesToFail_ (the types requested at construction time).
  const auto last = messageTypesToFail_.end();
  const auto match = std::find(messageTypesToFail_.begin(), last, type);
  return match != last;
}

float FaultyProcessGroupAgent::getDelayForMessage(MessageType type) const {
  // Returns the configured delay for `type`, or 0 when none was configured.
  auto entry = messageTypesToDelay_.find(type);
  if (entry != messageTypesToDelay_.end()) {
    return entry->second;
  }
  return 0;
}

MessageType FaultyProcessGroupAgent::messageStringToType(
    const std::string& messageString) const {
  // Translates a message type name coming from the Python tests into its
  // rpc::MessageType. Only the names listed below are accepted; any other
  // name fails the TORCH_CHECK below.
  //
  // The table is a function-local static, built on first call (initialization
  // of such statics is thread-safe since C++11). It is const because it is
  // only read after construction, which also rules out accidental mutation of
  // this process-wide shared state.
  static const std::unordered_map<std::string, MessageType> msgMap = {
      {"RREF_FORK_REQUEST", MessageType::RREF_FORK_REQUEST},
      {"RREF_CHILD_ACCEPT", MessageType::RREF_CHILD_ACCEPT},
      {"RREF_USER_DELETE", MessageType::RREF_USER_DELETE},
      {"CLEANUP_AUTOGRAD_CONTEXT_REQ",
       MessageType::CLEANUP_AUTOGRAD_CONTEXT_REQ},
      {"PYTHON_REMOTE_CALL", MessageType::PYTHON_REMOTE_CALL},
      {"SCRIPT_REMOTE_CALL", MessageType::SCRIPT_REMOTE_CALL},
      {"PYTHON_CALL", MessageType::PYTHON_CALL},
      {"SCRIPT_CALL", MessageType::SCRIPT_CALL},
      {"PYTHON_RREF_FETCH_CALL", MessageType::PYTHON_RREF_FETCH_CALL},
      {"SCRIPT_RREF_FETCH_CALL", MessageType::SCRIPT_RREF_FETCH_CALL}};
  const auto it = msgMap.find(messageString);
  TORCH_CHECK(
      it != msgMap.end(),
      "No mapping to rpc::MessageType exists for ",
      messageString);
  return it->second;
}

} // namespace rpc
} // namespace distributed
} // namespace torch