1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416
|
#pragma once
#include <chrono>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
// Default timeout used by Backend::Options when the caller does not supply
// one: 30 minutes, stored as std::chrono::milliseconds.
//
// `inline` gives the constant a single definition across the whole program.
// A plain namespace-scope `constexpr` has internal linkage, which creates one
// copy per translation unit. That copy is odr-used (bound by reference when
// copied) wherever it appears as a default argument of inline code in this
// header.
inline constexpr std::chrono::milliseconds kBackendDefaultTimeout{
    std::chrono::minutes(30)};
namespace c10d {
// Abstract base class for c10d communication backends.
//
// Every collective, point-to-point and hook entry point below has a default
// implementation that fails a TORCH_CHECK naming the concrete backend (via
// getBackendName()). Subclasses override only the operations they support.
class TORCH_API Backend : public torch::CustomClassHolder {
 public:
  // Backend Options is a base struct that defines the basic options
  // when constructing a Backend. Each Backend subclass should
  // extend this struct and define its options if it wants to provide more
  // config options (beyond basic ones defined here) to end user.
  struct TORCH_API Options : torch::CustomClassHolder {
    // `backend` is the backend name. `timeout` defaults to
    // kBackendDefaultTimeout.
    explicit Options(
        std::string backend,
        std::chrono::milliseconds timeout = kBackendDefaultTimeout)
        : timeout(timeout), backend(std::move(backend)) {}
    ~Options() override = default;

    std::chrono::milliseconds timeout;

    // backend name
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    const std::string backend;
  };

  // Declared here and defined out of line.
  explicit Backend(int rank, int size);

  // Pure virtual, so Backend itself cannot be instantiated. A pure virtual
  // destructor still needs an out-of-line definition.
  ~Backend() override = 0;

  // Rank of this backend within its group.
  int getRank() const {
    return rank_;
  }

  // Number of ranks in the group.
  int getSize() const {
    return size_;
  }

  // Returns an unique opaque ID of this backend that can be used to correlate
  // with its collectives. The ID is this object's address, so it is unique
  // only among backends that are alive at the same time.
  int64_t getID() const {
    return reinterpret_cast<std::intptr_t>(this);
  }

  // Whether this backend supports splitting. The base class returns false.
  virtual bool supportsSplitting() const {
    return false;
  }

  // Coalescing entry points. The base versions fail a TORCH_CHECK.
  virtual void startCoalescing() {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not implement startCoalescing"));
  }

  virtual c10::intrusive_ptr<Work> endCoalescing() {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not implement endCoalescing"));
  }

  // Subclasses must override this method to return the backend name.
  // NOTE: the by-value `const std::string` return type is kept deliberately.
  // Overrides must use the identical return type, so changing it here would
  // break every subclass.
  virtual const std::string getBackendName() const {
    TORCH_INTERNAL_ASSERT(false, "getBackendName is not implemented.");
  }

  // ---------------------------------------------------------------------
  // Collective operations. Each default implementation fails a TORCH_CHECK
  // reporting that this backend does not support the operation.
  // ---------------------------------------------------------------------

  virtual c10::intrusive_ptr<Work> broadcast(
      std::vector<at::Tensor>& /* tensors */,
      const BroadcastOptions& /* opts */ = BroadcastOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support broadcast"));
  }

  virtual c10::intrusive_ptr<Work> allreduce(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allreduce"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_sparse(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceOptions& /* opts */ = AllreduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce sparse"));
  }

  virtual c10::intrusive_ptr<Work> allreduce_coalesced(
      std::vector<at::Tensor>& /* tensors */,
      const AllreduceCoalescedOptions& /* opts */ =
          AllreduceCoalescedOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allreduce_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> reduce(
      std::vector<at::Tensor>& /* tensors */,
      const ReduceOptions& /* opts */ = ReduceOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support reduce"));
  }

  virtual c10::intrusive_ptr<Work> allgather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support allgather"));
  }

  // Gathers a single tensor inputBuffer into a single buffer outputBuffer that
  // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE.
  // For implementers of ProcessGroup API and advanced users only.
  // Note: this function will be deprecated in near future.
  virtual c10::intrusive_ptr<Work> _allgather_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support _allgather_base"));
  }

  // This function is deprecated and will be moved out of Backend to comms:
  // * do not add dependencies on this function,
  // * do not implement it in your Backend, implement _allgather_base
  //   instead.
  virtual c10::intrusive_ptr<Work> allgather_coalesced(
      std::vector<std::vector<at::Tensor>>& /* outputTensorLists */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_coalesced"));
  }

  // This function is a coalesced version of `allgather_into_tensor` (currently
  // still named as `_allgather_base`). Each tensor in the vector corresponds to
  // an input/output of one `allgather_into_tensor` operation.
  virtual c10::intrusive_ptr<Work> allgather_into_tensor_coalesced(
      std::vector<at::Tensor>& /* outputs */,
      std::vector<at::Tensor>& /* inputs */,
      const AllgatherOptions& /* opts */ = AllgatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support allgather_into_tensor_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> gather(
      std::vector<std::vector<at::Tensor>>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const GatherOptions& /* opts */ = GatherOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support gather"));
  }

  virtual c10::intrusive_ptr<Work> scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ScatterOptions& /* opts */ = ScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support scatter"));
  }

  virtual c10::intrusive_ptr<Work> reduce_scatter(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<std::vector<at::Tensor>>& /* inputTensors */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support reduce_scatter"));
  }

  virtual c10::intrusive_ptr<Work> _reduce_scatter_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support _reduce_scatter_base"));
  }

  // This function is a coalesced version of `reduce_scatter_tensor` (currently
  // still named as `_reduce_scatter_base`). Each tensor in the vector
  // corresponds to an input/output of one `reduce_scatter_tensor` operation.
  virtual c10::intrusive_ptr<Work> reduce_scatter_tensor_coalesced(
      std::vector<at::Tensor>& /* outputs */,
      std::vector<at::Tensor>& /* inputs */,
      const ReduceScatterOptions& /* opts */ = ReduceScatterOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            getBackendName(),
            " does not support reduce_scatter_tensor_coalesced"));
  }

  virtual c10::intrusive_ptr<Work> alltoall_base(
      at::Tensor& /* outputBuffer */,
      at::Tensor& /* inputBuffer */,
      std::vector<int64_t>& /* outputSplitSizes */,
      std::vector<int64_t>& /* inputSplitSizes */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support alltoall_base"));
  }

  // `opts` is left unnamed, like every other default stub in this class, so
  // the base version does not trigger -Wunused-parameter.
  virtual c10::intrusive_ptr<Work> alltoall(
      std::vector<at::Tensor>& /* outputTensors */,
      std::vector<at::Tensor>& /* inputTensors */,
      const AllToAllOptions& /* opts */ = AllToAllOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support alltoall"));
  }

  virtual void monitoredBarrier(
      const BarrierOptions& /* unused */,
      bool /* unused */ = false) {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not support monitoredBarrier, only GLOO supports monitored barrier."));
  }

  // Agrees on an initial sequence number for the whole group by having rank 0
  // create it and broadcast it to other ranks using the store. Only implemented
  // for GLOO and NCCL backends currently.
  virtual void setSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // Retrieves the current sequence number for the whole group, which should be
  // in sync. If the returned number is not consistent across the group, it
  // may indicate that there is some sort of collective desynchronization.
  virtual uint64_t getSequenceNumberForGroup() {
    auto backendName = getBackendName();
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ",
            backendName,
            " does not yet support sequence numbers."));
  }

  // ---------------------------------------------------------------------
  // Point-to-point operations and barrier. The default implementations fail.
  // ---------------------------------------------------------------------

  virtual c10::intrusive_ptr<Work> send(
      std::vector<at::Tensor>& /* tensors */,
      int /* dstRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support send"));
  }

  virtual c10::intrusive_ptr<Work> recv(
      std::vector<at::Tensor>& /* tensors */,
      int /* srcRank */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support recv"));
  }

  virtual c10::intrusive_ptr<Work> recvAnysource(
      std::vector<at::Tensor>& /* tensors */,
      int /* tag */) {
    TORCH_CHECK(
        false,
        c10::str(
            "Backend ", getBackendName(), " does not support recvAnysource"));
  }

  virtual c10::intrusive_ptr<Work> barrier(
      const BarrierOptions& /* opts */ = BarrierOptions()) {
    TORCH_CHECK(
        false,
        c10::str("Backend ", getBackendName(), " does not support barrier"));
  }

  // Registers a hook to be called on work completion. The base version fails.
  // The parameter is unnamed to avoid -Wunused-parameter.
  virtual void registerOnCompletionHook(
      std::function<void(std::shared_ptr<WorkInfo>)>&& /* hook */) {
    TORCH_CHECK(
        false,
        "Only ProcessGroupNCCL supports onCompletion hook, but got ",
        getBackendName(),
        " backend.");
  }

  virtual void waitForPendingWorks() {
    TORCH_CHECK(
        false,
        "Only ProcessGroupNCCL supports waitForPendingWorks, but got ",
        getBackendName(),
        " backend.");
  }

  virtual void enableCollectivesTiming() {
    TORCH_CHECK(
        false,
        "Backend ",
        getBackendName(),
        " is missing implementation of enableCollectivesTiming.");
  }

  // True when an on-completion hook has been stored in onCompletionHook_.
  bool hasHooks() const {
    return onCompletionHook_ != nullptr;
  }

  // Do not call this directly, use ProcessGroup::setGroupName instead.
  void setGroupUid(const std::string& pg_uid) {
    pg_uid_ = pg_uid;
  }

  const std::string& getGroupUid() const {
    return pg_uid_;
  }

  void setGroupDesc(const std::string& desc) {
    pg_desc_ = desc;
  }

  const std::string& getGroupDesc() const {
    return pg_desc_;
  }

  // See similar functions in ProcessGroup.hpp for context.
  std::optional<at::Device> getBoundDeviceId() const {
    return bound_device_id_;
  }

  // Perform an eager connect to the specified device if the backend supports
  // it.
  virtual void eagerConnectSingleDevice(at::Device /* device */) {
    // no-op in the default case; this is an optimization some
    // backends may perform
  }

  // Stores the device this backend is bound to. Passing std::nullopt clears
  // it. A device that is provided must carry an explicit index.
  void setBoundDeviceId(std::optional<at::Device> device) {
    if (device) {
      TORCH_CHECK(device->has_index(), "setBoundDeviceId must have an index");
    }
    bound_device_id_ = device;
  }

 protected:
  // Implementations of this interface need to call this to setup
  // appropriate logging etc.
  void init();

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int rank_;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const int size_;
  // Debug level setting. It is parsed once when ProcessGroup is constructed and
  // remains the same across use of this process group.
  DebugLevel dist_debug_level_;
  // Written by setGroupUid / setGroupDesc.
  std::string pg_uid_;
  std::string pg_desc_;
  // Queried by hasHooks(). Nothing in this class assigns it; presumably set by
  // subclasses that implement registerOnCompletionHook (TODO confirm).
  std::function<void(std::shared_ptr<WorkInfo>)> onCompletionHook_;
  // Written by setBoundDeviceId, read by getBoundDeviceId.
  std::optional<at::Device> bound_device_id_;
};
} // namespace c10d
|