#pragma once
#include <c10/core/thread_pool.h>
#include <c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/rpc/request_callback.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <atomic>
#include <thread>
namespace torch {
namespace distributed {
namespace rpc {
constexpr auto kDefaultNumSendRecvThreads = 4;
struct ProcessGroupRpcBackendOptions : public RpcBackendOptions {
ProcessGroupRpcBackendOptions(
int num_send_recv_threads,
float rpc_timeout,
std::string init_method)
: RpcBackendOptions(rpc_timeout, init_method),
numSendRecvThreads(num_send_recv_threads) {
TORCH_CHECK(
num_send_recv_threads > 0,
"Cannot create ProcessGroup RPC backend with ",
num_send_recv_threads,
" threads in the thread-pool.");
}
int numSendRecvThreads;
};
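// Illustrative construction of the backend options (the argument values below
// are made-up examples, except for kDefaultNumSendRecvThreads):
//
//   ProcessGroupRpcBackendOptions opts(
//       /*num_send_recv_threads=*/kDefaultNumSendRecvThreads,
//       /*rpc_timeout=*/60.0f,
//       /*init_method=*/"tcp://127.0.0.1:29500");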
// SendWork and RecvWork will be put into a task queue, and later picked up by
// worker threads from the same ThreadPool.
struct SendWork {
SendWork(const WorkerInfo& to, Message&& message)
: to_(to), message_(message) {}
const WorkerInfo& to_;
Message message_;
};
// SendWork wraps a Message and RecvWork wraps a Tensor. This difference allows
// us to run serialization/deserialization in the worker threads.
struct RecvWork {
RecvWork(
const WorkerInfo& from,
MessageType type,
int64_t id,
torch::Tensor&& payload)
: from_(from), type_(type), id_(id), payload_(payload) {}
const WorkerInfo& from_;
const MessageType type_;
const int64_t id_;
torch::Tensor payload_;
};
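// ProcessGroupAgent is an RpcAgent implementation built on top of
// c10d::ProcessGroup send/recv. Illustrative construction sketch (in practice
// the agent is typically created through the RPC backend registration path;
// the values below are hypothetical):
//
//   auto agent = std::make_shared<ProcessGroupAgent>(
//       "worker0",
//       pg, // std::shared_ptr<c10d::ProcessGroup>
//       kDefaultNumSendRecvThreads,
//       std::chrono::milliseconds(60 * 1000),
//       std::move(cb)); // std::unique_ptr<RequestCallback>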
class TORCH_API ProcessGroupAgent : public RpcAgent {
public:
ProcessGroupAgent(
std::string workerName,
std::shared_ptr<c10d::ProcessGroup> pg,
int numSendRecvThreads,
std::chrono::milliseconds rpcTimeout,
std::unique_ptr<RequestCallback> cb);
const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
const WorkerInfo& getWorkerInfo(worker_id_t id) const override;
std::vector<WorkerInfo> getWorkerInfos() const override;
void join() override;
void sync() override;
void startImpl() override;
void shutdownImpl() override;
~ProcessGroupAgent() override;
std::unordered_map<std::string, std::string> getMetrics() override;
protected:
// This method wraps the destination information and the message into a
// SendWork object, and puts the SendWork into a queue. Another thread will
// consume the SendWork from the queue and send it out.
std::shared_ptr<FutureMessage> send(
const WorkerInfo& to,
Message&& message,
const float rpcTimeoutSeconds = kUnsetRpcTimeout) override;
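// Caller-side sketch for the send() declared above (illustrative only; it
// assumes the caller invokes it through the public RpcAgent interface):
//
//   std::shared_ptr<FutureMessage> fut = agent.send(to, std::move(msg));
//   fut->wait(); // or attach a callback; the future completes when the
//                // response, an error, or a timeout arrives.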
// put SendWork into a queue and notify the worker thread
virtual void enqueueSend(SendWork work);
// Bypass handleSend() logic and send a message to self rank
virtual void sendToSelf(Message&& message);
private:
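// A simple thread-safe, per-destination message counter. One instance each is
// kept for sent and received messages (see sendCounts_ / recvCounts_ below);
// their snapshots drive the termination detection described in
// Note [Termination Detection].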
class MessageCounter {
public:
explicit MessageCounter(int worldSize);
void increment(int dst);
std::vector<int64_t> snapshot();
private:
std::vector<int64_t> counters_;
std::mutex mutex_;
};
// TODO: this class should inherit from a MetricsTracker, and can be extended
// to track num_sends, recvs, average size of messages, etc.
struct AverageMetricsTracker {
std::string key_;
uint64_t currentSum_;
uint64_t currentCount_;
explicit AverageMetricsTracker(
std::string key,
uint64_t currentSum = 0,
uint64_t currentCount = 0);
void addData(uint64_t dataPoint);
double computeAverage();
};
// The FutureInfo struct stores a shared_ptr to the future, as well as
// additional information to manage timeouts and the destination rank, which
// is needed for termination detection.
struct FutureInfo {
std::shared_ptr<FutureMessage> future_;
steady_clock_time_point endTime_;
int dstRank_;
std::chrono::milliseconds timeout_;
FutureInfo(
const std::shared_ptr<FutureMessage>& future,
const steady_clock_time_point& endTime,
int dstRank,
const std::chrono::milliseconds timeout)
: future_(future),
endTime_(endTime),
dstRank_(dstRank),
timeout_(timeout) {}
FutureInfo() = delete;
};
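// Collect worker names from all peers so that nameMap_ (worker name -> rank)
// and allWorkerInfo_ can be populated.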
void collectNames();
// handle a SendWork request. This serializes the payload inside the work
// object, and sends the message to the receiver using the underlying
// ProcessGroup.
void handleSend(const SendWork& work);
// put RecvWork into a queue and notify the worker thread
void enqueueRecv(RecvWork work);
// handle a RecvWork request. Return true if we should increment recvCounts,
// false if not (i.e. if the RPC timed out and we are getting a result after
// the timeout). This ensures that the messages accounted for in
// hasPendingMessage() are tallied properly during a graceful shutdown.
bool handleRecv(RecvWork& work);
// Loop that receives and processes messages
void listenLoopInternal();
// Calls listenLoopInternal and handles errors such as timeouts on the
// process group.
void listenLoop();
// exception_ptr corresponding to an exception raised in listenLoop (if
// there is one), and a lock to guard access.
std::exception_ptr listenLoopException_;
std::mutex listenLoopExceptionMutex_;
// poll for timed out RPCs
void pollTimedOutRPCs();
// process timed out futures
const std::vector<FutureInfo> processTimedOutFutures();
// compute the remaining time for an RPC, given its end time.
const std::chrono::milliseconds getRPCRemainingTime(
const std::chrono::milliseconds& rpcEndTime) const;
// A helper function to mark a future in the futures_ map with a message. The
// future is marked with the passed-in message and then removed from the
// futures_ map. It is also removed from the futureTimeouts_ map, since these
// maps are kept in sync.
void markFutureWithError(Message& message);
void markFutureWithError(int64_t id, std::string errorMsg);
// Note [Termination Detection]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//
// RpcAgent implementations must properly detect termination. Otherwise, it
// would result in message loss, RRef leaks, or process hangs. It is not
// sufficient to just wait for the thread pool to finish processing all tasks
// after all processes hit the join function. There could be nested rpc/remote
// calls, meaning that an empty task queue in the thread pool does not mean
// there will be no tasks added in the future. Moreover, in the listenLoop,
// there is a period of time when the message has been received but not yet
// inserted into the thread pool, which also suggests that the empty task
// queue is not a good indicator for termination.
//
// To detect termination, each ProcessGroupAgent maintains a sent message
// counter and a received message counter. The sent message counter is
// incremented whenever a message is sent, and the received message counter is
// only incremented when a message has been processed. During termination, all
// ProcessGroupAgent instances run an allgather to collect counters from all
// peers, which means that all agents will have a consistent view on the
// message count snapshot. They would only terminate if all sent/received
// message counters match.
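// A rough sketch of that check (illustrative pseudocode only, not the actual
// implementation):
//
//   sent = sendCounts_.snapshot();  // sent[dst] for every destination rank
//   recv = recvCounts_.snapshot();  // recv[src] for every source rank
//   allSent, allRecv = allgather({sent, recv}) over pg_;
//   pending = exists ranks (i, j) such that allSent[i][j] != allRecv[j][i];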
bool hasPendingMessage();
int64_t nextId() {
return ++nextId_;
}
std::shared_ptr<c10d::ProcessGroup> pg_;
// worker name -> rank
std::unordered_map<std::string, worker_id_t> nameMap_;
std::vector<WorkerInfo> allWorkerInfo_;
// Record the number of messages sent to and received from each peer. The recv
// counter is only incremented after the message is processed. join() uses
// allgather to collect the counts from all peers, uses these counters to
// detect global termination, and only exits when all sent messages have been
// processed.
MessageCounter sendCounts_;
MessageCounter recvCounts_;
std::atomic<int64_t> nextId_;
// one mutex per ProcessGroup rank, as ProcessGroup::send is not thread-safe
// when using the same tag.
std::vector<std::mutex> sendMutexes_;
std::thread listenerThread_;
// A thread to poll existing futures and check for timed out ones.
std::thread futureTimeoutThread_;
// Lock and shared_ptr to the currently pending recv work, set in listenLoop()
// and interruptible in shutdown().
std::mutex recvWorkMutex_;
std::shared_ptr<c10d::ProcessGroup::Work> recvWork_;
// Map of dst rank to the currently outstanding sends that we are waiting on.
// In the case of a call to ::shutdown() while we are still waiting on these
// sends, the pending sends contained in this map will be aborted, allowing
// the waiting thread to be unblocked.
std::unordered_map<
worker_id_t,
std::set<std::shared_ptr<c10d::ProcessGroup::Work>>>
currentPendingSends_;
// Lock to serialize access to the above map.
std::mutex pendingSendMutex_;
// A ThreadPool that processes both SendWork and RecvWork. There are two
// motivations for adding a ThreadPool:
// (1) RPC serialization/deserialization and processing can be expensive,
// hence using multiple threads to speed it up.
// (2) The current RPC API does not support asynchronous UDFs, i.e., UDFs
// cannot yield in the middle of execution to wait for IO and resume when
// the IO is done. This would result in deadlocks when we have nested RPC
// calls.
// NB: Ideally, this should be addressed by supporting asynchronous UDFs.
// This is just a temporary solution for (2).
ThreadPool threadPool_;
// Atomic to indicate whether the timeout thread is enabled.
std::atomic<bool> timeoutThreadEnabled_;
// Mapping of request id to FutureInfo struct.
std::unordered_map<int64_t, FutureInfo> futures_;
// A map to keep track of when futures time out. The map is keyed by the time
// (millisecond-level precision) at which the future will expire. This is so
// that timed-out futures can be efficiently cleaned up, and we can quickly
// exit if we find a future that has not timed out. The values correspond to an
// unordered_set of the ids of the futures that expire at that time. This map
// must be kept in sync with the above futures_ map.
std::map<steady_clock_time_point, std::unordered_set<int64_t>>
futureTimeouts_;
mutable std::mutex futureMutex_;
mutable std::condition_variable futureCV_;
// CV to wake up the watchdog thread that watches for timed-out futures.
std::condition_variable futureTimeoutCV_;
// Metrics tracked for ProcessGroupAgent.
enum ProcessGroupAgentMetrics {
GIL_WAIT_TIME = 0,
N_METRICS,
};
std::mutex metricsMutex_;
std::vector<std::unique_ptr<AverageMetricsTracker>> metrics_;
void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;
std::atomic<int32_t> clientActiveCalls_{0};
std::atomic<int32_t> serverActiveCalls_{0};
std::atomic<int32_t> serverActiveAsyncCalls_{0};
};
} // namespace rpc
} // namespace distributed
} // namespace torch