1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339
|
#pragma once
#include <torch/csrc/distributed/rpc/message.h>
#include <torch/csrc/distributed/rpc/request_callback.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <cctype>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <mutex>
#include <thread>
namespace torch::distributed::rpc {
// Hash map pairing one c10::Device (key) with another c10::Device (value).
// NOTE(review): presumably maps a caller-side device to the device that should
// be used on the destination worker -- confirm against the agent
// implementations.
using DeviceMap = std::unordered_map<c10::Device, c10::Device>;
// NB: the constants below are `inline constexpr` so that every translation
// unit including this header shares a single definition instead of getting its
// own internal-linkage copy.

// Default RPC timeout, in seconds.
inline constexpr float kDefaultRpcTimeoutSeconds = 60;
// Unset RPC timeout. This is the value agent::send() will have if user does not
// pass in a specific timeout, and indicates that we must use the default
// timeout for RPCs.
inline constexpr float kUnsetRpcTimeout = -1;
// Default value for RpcBackendOptions::initMethod.
inline constexpr auto kDefaultInitMethod = "env://";
// Multiplier used to convert a value in seconds to milliseconds.
inline constexpr float kSecToMsConversion = 1000;
// Error message template for RPCs that exceeded their timeout. The "{}"
// placeholder is meant to be filled with the timeout in milliseconds.
inline constexpr auto kRpcTimeoutErrorStr =
    "RPC ran for more than set timeout ({} ms) and will now be marked with an error";
// Time point measured on the monotonic std::chrono::steady_clock.
using steady_clock_time_point =
    std::chrono::time_point<std::chrono::steady_clock>;
// Callable that takes a qualified name (c10::QualifiedName) and returns the
// corresponding JIT c10::StrongTypePtr.
// This is the same as jit::TypeResolver; jit::TypeResolver is deliberately not
// imported here because doing so could introduce cyclic dependencies.
using TypeResolver =
std::function<c10::StrongTypePtr(const c10::QualifiedName&)>;
struct TORCH_API RpcBackendOptions {
RpcBackendOptions()
: RpcBackendOptions(kDefaultRpcTimeoutSeconds, kDefaultInitMethod) {}
RpcBackendOptions(float rpcTimeoutSeconds, std::string initMethod)
: rpcTimeoutSeconds(rpcTimeoutSeconds),
initMethod(std::move(initMethod)) {
TORCH_CHECK(rpcTimeoutSeconds >= 0, "RPC Timeout must be non-negative");
}
float rpcTimeoutSeconds;
std::string initMethod;
};
// A globally unique ID to identify an RpcAgent
struct TORCH_API WorkerInfo : torch::CustomClassHolder {
WorkerInfo(std::string name, int64_t id);
WorkerInfo(std::string name, worker_id_t id);
bool operator==(const WorkerInfo& rhs) {
return (id_ == rhs.id_) && (name_ == rhs.name_);
}
static constexpr size_t MAX_NAME_LEN = 128;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const std::string name_;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const worker_id_t id_;
};
// Helper type whose only declared member is a default constructor, defined
// out of line.
// NOTE(review): the name suggests that constructing an instance performs a
// one-time registration related to WorkerInfo -- confirm against the
// definition.
struct TORCH_API RegisterWorkerInfoOnce {
RegisterWorkerInfoOnce();
};
// Stream insertion operator for WorkerInfo. Takes the output stream and the
// WorkerInfo to print, and returns a std::ostream reference. The exact text
// format is determined by the out-of-line definition.
TORCH_API std::ostream& operator<<(
std::ostream& os,
const WorkerInfo& workerInfo);
// Options that configure the RPC retry protocol.
struct TORCH_API RpcRetryOptions {
  // Defaulted constructor, in line with the other Options structs in the RPC
  // codebase. Input validation (TORCH_CHECKs) is performed by the
  // sendWithRetries function rather than here.
  RpcRetryOptions() = default;

  // Upper bound on the number of times the RPC is retried.
  int maxRetries = 5;
  // Initial wait between two consecutive RPC send attempts.
  std::chrono::milliseconds rpcRetryDuration{1000};
  // Base of the exponential backoff used when computing later wait
  // durations.
  float retryBackoff = 1.5f;
};
// Struct that stores all the metadata needed to retry a given RPC.
struct TORCH_API RpcRetryInfo {
RpcRetryInfo(
const WorkerInfo& to,
c10::intrusive_ptr<Message> message,
c10::intrusive_ptr<JitFuture> originalFuture,
int retryCount,
RpcRetryOptions options)
: to_(to),
message_(std::move(message)),
originalFuture_(std::move(originalFuture)),
retryCount_(retryCount),
options_(options) {}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const WorkerInfo& to_;
c10::intrusive_ptr<Message> message_;
// Future that is returned to the caller of sendWithRetries().
c10::intrusive_ptr<JitFuture> originalFuture_;
// Number of send attempts completed so far.
int retryCount_;
RpcRetryOptions options_;
};
// ``RpcAgent`` is the base class for sending and receiving RPC messages. It
// provides a unified ``send`` API for both request and response messages, and
// will invoke the given ``RequestCallback`` to process received requests. It
// should immediately become ready to serve request and accept response after
// construction.
class TORCH_API RpcAgent {
 public:
  // `WorkerInfo` is the globally unique identifier for this RpcAgent instance.
  // It contains a ``name_`` field and an ``id_`` field. ``name_`` is the
  // globally unique name for this ``RpcAgent``. It is up to the ``RpcAgent``
  // implementation to determine how to resolve names. ``id_`` is the globally
  // unique ID for this ``RpcAgent``. This should be determined by the
  // ``RpcAgent`` implementation.
  // The ``RequestCallback`` will be invoked to handle received requests. This
  // ``RpcAgent`` base class makes no assumption on the thread-safeness of the
  // ``RequestCallback``. ``RpcAgent`` implementations need to make sure that
  // their threading model conforms to ``RequestCallback``'s requirement.
  // NB: RpcAgent implementations should not start serving requests until
  // ``start()`` is called, as there could be other contexts that have not been
  // initialized yet at this time.
  RpcAgent(
      WorkerInfo id,
      std::unique_ptr<RequestCallback> cb,
      std::chrono::milliseconds rpcTimeout);

  virtual ~RpcAgent();

  // Send a message to the ``RpcAgent`` of id ``to`` and returns a
  // ``JitFuture`` ptr. The implementation must be asynchronous, i.e., it
  // cannot block until it receives the response.
  //
  // If ``message.isRequest()`` is true, the ``JitFuture`` will be
  // completed when the response arrives. For other message types, the Future
  // should be ignored by the caller.
  //
  // ``rpcTimeoutSeconds`` defaults to ``kUnsetRpcTimeout``, which indicates
  // that the default RPC timeout must be used.
  virtual c10::intrusive_ptr<JitFuture> send(
      const WorkerInfo& to,
      c10::intrusive_ptr<Message> message,
      const float rpcTimeoutSeconds = kUnsetRpcTimeout,
      const DeviceMap& deviceMap = {}) = 0;

  // Retries sending the message up to maxRetries times until an ACK is
  // received. The duration between consecutive sends is increased over
  // time using an exponential backoff algorithm.
  //
  // Sends ``message`` to the ``RpcAgent`` of id ``to`` and returns a
  // ``JitFuture`` ptr, just like send(). Caller can specify the maximum
  // number of retries for this RPC (default is 5), initial duration between
  // sends (default is 1000ms), and backoff constant (default is 1.5) by
  // passing in the RpcRetryOptions struct. This API might end up
  // executing a method twice on the remote end (it does not guarantee
  // exactly-once semantics). Therefore, the user must ensure their requests
  // are idempotent.
  c10::intrusive_ptr<JitFuture> sendWithRetries(
      const WorkerInfo& to,
      c10::intrusive_ptr<Message> message,
      RpcRetryOptions retryOptions = RpcRetryOptions());

  // Return a reference to the ``WorkerInfo`` of this RpcAgent.
  // NB: not using ``std::optional<const std::string&>`` here because we might
  // need to create a separate RPC API lib and avoid forcing all ``RpcAgent``
  // implementations to depend on libtorch.
  const WorkerInfo& getWorkerInfo() const;

  // Return a reference to the ``WorkerInfo`` of the given ``workerName``.
  virtual const WorkerInfo& getWorkerInfo(
      const std::string& workerName) const = 0;

  // Return a reference to the ``WorkerInfo`` of the given worker ``id``.
  virtual const WorkerInfo& getWorkerInfo(worker_id_t id) const = 0;

  // Return a vector of ``WorkerInfo`` objects (by value).
  virtual std::vector<WorkerInfo> getWorkerInfos() const = 0;

  // Retrieve the timeout for all RPCs.
  inline std::chrono::milliseconds getRpcTimeout() const {
    return rpcTimeout_.load();
  }

  // Set the timeout for all RPCs
  inline void setRpcTimeout(const std::chrono::milliseconds& rpcTimeout) {
    rpcTimeout_.store(rpcTimeout);
  }

  // Call sync and join all internal threads. This method should be called
  // before every RPC process exits.
  virtual void join(bool shutdown = false, float timeout = 0) = 0;

  // Synchronize this process with other ``RpcAgent`` processes. Block until
  // all ``RpcAgent``s reach this method and send all pending messages.
  virtual void sync() = 0;

  // Sets up backend-agnostic state for accepting requests. Currently, this
  // entails setting rpcAgentRunning_ to true, creating the retry thread, and
  // calling the backend's startImpl.
  void start();

  // Derived classes must override this function to start accepting requests.
  // This is used to initialize any backend-specific state. Users must call
  // start, not startImpl, to initialize the RPC Agent.
  virtual void startImpl() = 0;

  // Stop accepting requests and shutdown the RPC framework as soon as possible
  // by terminating all RPC threads.
  void shutdown();

  // Derived classes must override this function to stop accepting requests.
  // This is used to clean up any backend-specific state. Users must call
  // shutdown, not shutdownImpl, to shutdown the RPC Agent.
  virtual void shutdownImpl() = 0;

  // Check if current RPC agent is set.
  static bool isCurrentRpcAgentSet();

  // Retrieve the valid current RPC agent.
  static std::shared_ptr<RpcAgent> getCurrentRpcAgent();

  // Set the current RPC agent.
  static void setCurrentRpcAgent(std::shared_ptr<RpcAgent> rpcAgent);

  // Retrieve metrics as KV map
  virtual std::unordered_map<std::string, std::string> getMetrics() = 0;

  // Retrieve debug info in addition to metrics as KV map
  virtual std::unordered_map<std::string, std::string> getDebugInfo();

  // Flag to control whether GIL wait times
  // should be profiled or not.
  void enableGILProfiling(bool flag);

  // Retrieve whether we should profile GIL wait times or not.
  bool isGILProfilingEnabled();

  // Set type resolver that will be passed to JIT pickler to resolve type Ptr
  // based on type str.
  void setTypeResolver(std::shared_ptr<TypeResolver> typeResolver);

  // Get the type resolver
  std::shared_ptr<TypeResolver> getTypeResolver();

  // Retrieves the device map for the provided destination worker.
  virtual DeviceMap getDeviceMap(const WorkerInfo& dst) const;

  // Retrieve the (non-CPU) devices that are supported by the agent.
  virtual const std::vector<c10::Device>& getDevices() const;

 protected:
  // Identity of this agent; immutable after construction.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const WorkerInfo workerInfo_;
  // Callback invoked to process received requests.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  const std::unique_ptr<RequestCallback> cb_;
  // Timeout for all RPCs; read and written atomically through
  // getRpcTimeout() / setRpcTimeout().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::atomic<std::chrono::milliseconds> rpcTimeout_;
  // NOTE(review): presumably the flag behind enableGILProfiling() /
  // isGILProfilingEnabled() -- confirm against the definitions.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::atomic<bool> profilingEnabled_;
  // Type resolver handed out by getTypeResolver().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::shared_ptr<TypeResolver> typeResolver_;
  // Atomic boolean indicating whether this agent is running. It controls
  // whether several background threads should be running. It is set in
  // RpcAgent::start() and unset in the derived class shutdown().
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::atomic<bool> rpcAgentRunning_;

 private:
  static std::shared_ptr<RpcAgent> currentRpcAgent_;

  // Add GIL wait time data point to metrics
  virtual void addGilWaitTime(const std::chrono::microseconds gilWaitTime) = 0;
  friend class PythonRpcHandler;

  // Map that stores metadata for RPC's that may need to be re-tried as well as
  // the timepoint at which we should re-try them.
  std::map<
      steady_clock_time_point,
      std::unordered_set<std::shared_ptr<RpcRetryInfo>>>
      rpcRetryMap_;

  // Thread that checks for retryable RPC's in the rpcRetryMap_ and sleeps until
  // the next unACKed RPC's timeout has expired.
  std::thread rpcRetryThread_;

  // Function that rpcRetryThread_ calls in a loop as long as RpcAgent is
  // running.
  void retryExpiredRpcs();

  // This is the callback attached to futures corresponding to send retries.
  // This handles 3 cases: 1). send was completed, 2). send failed with an
  // error and we've done maxRetries failed send attempts, and 3). send
  // failed with an error and we have more retries to go. In case 1, we mark
  // the original future as complete. In case 2, we mark the future with an
  // error and do not retry again. In case 3, we move the RpcRetryInfo struct
  // to another time point in the map to schedule the RPC for a future send.
  void rpcRetryCallback(
      JitFuture& message,
      steady_clock_time_point newTime,
      std::shared_ptr<RpcRetryInfo> earliestRpc);

  // Function that uses the exponential backoff algorithm to compute the next
  // time point to retry a given RPC.
  //
  // ``options`` supplies the initial retry duration and the backoff constant
  // and is only read; ``retryCount`` is the exponent applied to the backoff
  // constant. Returns the current steady-clock time plus the computed delay,
  // at millisecond resolution.
  inline steady_clock_time_point computeNewRpcRetryTime(
      const RpcRetryOptions& options,
      int retryCount) {
    // The exponential backoff algorithm being used here is:
    // newTime = timeNow + (retryDuration * (backoffConstant ^ retryCount)).
    // std::pow yields a floating-point factor, so the product is a
    // floating-point duration that is truncated back to whole milliseconds.
    std::chrono::milliseconds timedelta =
        std::chrono::duration_cast<std::chrono::milliseconds>(
            options.rpcRetryDuration *
            std::pow(options.retryBackoff, retryCount));
    return std::chrono::time_point_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() + timedelta);
  }

  // Condition Variable to signal when the rpcRetryMap_ has been populated.
  std::condition_variable rpcRetryMapCV_;
  // Mutex to protect rpcRetryMap_.
  std::mutex rpcRetryMutex_;
};
} // namespace torch::distributed::rpc
namespace std {
// std::hash specialization for WorkerInfo, allowing it to be used as a key in
// unordered containers. The hash value is simply the worker's id_ converted to
// std::size_t; name_ does not take part in the hash.
template <>
struct hash<torch::distributed::rpc::WorkerInfo> {
  std::size_t operator()(
      const torch::distributed::rpc::WorkerInfo& worker_info) const noexcept {
    const auto idAsHash = static_cast<std::size_t>(worker_info.id_);
    return idAsHash;
  }
};
} // namespace std
|