#include <torch/csrc/distributed/c10d/NCCLUtils.hpp>

#include <c10/util/env.h>

#ifdef USE_C10D_NCCL

#include <chrono>
#include <mutex>
#include <vector>

namespace c10d {

ncclComm_t NCCLComm::getNcclComm() {
  LockType lock(mutex_);
  if (aborted_) {
    auto commFailureMsg = commFailureReason_ != std::nullopt
        ? c10::str(" Original reason for failure was: ", *commFailureReason_)
        : "";
    TORCH_CHECK_WITH(
        DistBackendError,
        false,
        c10::str(
            "NCCL communicator was aborted on rank ",
            rank_,
            ". ",
            commFailureMsg));
  }
  // In non-blocking mode, ensure comm is ready.
  if (nonBlocking_) {
    // Wait with long interval if communicator is being initialized.
    bool longInterval = !initialized_;
    waitReady(longInterval);
    // ncclComm_ should be initialized by now
  }
  if (!initialized_) {
    // TODO: see if we can consolidate other `initialized_` flipping here.
    // Maintaining it elsewhere is some work.
    initialized_ = true;
    LOG(INFO) << "Rank " << rank_ << ": NCCL communicator " << repr()
              << " is initialized.";
  }
  return ncclComm_;
}
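
// Illustrative call pattern (a hypothetical caller-side sketch): consumers
// are expected to go through getNcclComm() rather than reading ncclComm_
// directly, so that abort status and non-blocking initialization are handled
// uniformly:
//
//   ncclComm_t raw = comm->getNcclComm(); // waits for init in non-blocking
//                                         // mode; throws if comm was aborted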

// Wait for the communicator to be ready. This is a blocking function.
// Arguments:
//   longInterval: if true, wait with a sleep interval; otherwise, wait with
//   `sched_yield`, which is faster (but occupies the CPU more frequently).
void NCCLComm::waitReady(bool longInterval) {
  LockType lock(mutex_);
  if (aborted_)
    return;
  // If the timeout is reached, the check macros throw an exception.
  if (longInterval) {
    C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt);
  } else {
    C10D_NCCL_CHECK_TIMEOUT(ncclInProgress, ncclComm_, std::nullopt);
  }
}

// TODO: why do we have `!defined(FBCODE_CAFFE2)` here?
#if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2)
// The last argument to the split() API (`ranks_ull`) is not used here; it
// exists to support multiple implementations.
std::shared_ptr<NCCLComm> NCCLComm::split(
    NCCLComm* source,
    int color_id,
    int rank,
    ncclConfig_t& config,
    std::vector<uint64_t>& ranks_ull) {
  TORCH_CHECK(
      color_id >= NCCL_SPLIT_NOCOLOR,
      "Color must be a non-negative value or NCCL_SPLIT_NOCOLOR (-1)"
      ", but got ",
      color_id);
  LOG(INFO) << "Rank " << source->rank_ << ": split from parent comm "
            << source->repr() << " with color_id " << color_id << " and rank "
            << rank;
  at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_);
  auto comm = std::make_shared<NCCLComm>();
  // This call will block until the source communicator is initialized.
  auto sourceComm = source->getNcclComm();
#ifndef NCCL_HAS_COMM_NONBLOCKING
  C10D_NCCL_CHECK(
      ncclCommSplit(sourceComm, color_id, rank, &(comm->ncclComm_), &config),
      std::nullopt);
#else
  // After calling ncclCommSplit in non-blocking mode, we should wait for the
  // source communicator to leave the ncclInProgress state.
  // Reason 1: it is unsafe to call new operations on the parent comm while it
  // is in the ncclInProgress state.
  // Reason 2: as of NCCL 2.23, the ptr value of the child comm will not be
  // filled in until the state of the parent comm is ncclSuccess. This may
  // change in the future. See:
  // https://github.com/NVIDIA/nccl/issues/1472
  C10D_NCCL_CHECK_TIMEOUT_SLEEP(
      ncclCommSplit(sourceComm, color_id, rank, &(comm->ncclComm_), &config),
      sourceComm, // wait on the parent comm
      std::nullopt);
  if (color_id >= 0) {
    // Waiting on the parent comm above still does not seem to guarantee that
    // the child comm ptr is valid. Therefore we add a manual wait here for
    // safety.
    // TODO: remove this wait once NCCL fixes the semantics.
    auto startTime = std::chrono::steady_clock::now();
    auto timeout = nccl_nonblocking_timeout();
    while (!comm->ncclComm_) {
      C10D_CHECK_TIMEOUT(startTime, timeout);
      C10D_SCHED_SLEEP();
    }
  }
  // comm->ncclComm_ should hold a valid ptr by now, but it is not necessarily
  // initialized. Rely on getNcclComm() to wait for its initialization.
#endif
  ++source->ncclCommSplitCounter_;
  comm->rank_ = rank;
  // The child comm should be on the same device as the parent comm.
  comm->deviceIndex_ = source->deviceIndex_;
  comm->nonBlocking_ = config.blocking == 0;
  LOG(INFO) << "Rank " << source->rank_ << ": created child comm "
            << comm->repr() << " with color_id " << color_id;
  return comm;
}
#endif
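
// Illustrative split call (a sketch; `parent` is a hypothetical
// std::shared_ptr<NCCLComm>, and NCCL_CONFIG_INITIALIZER comes from nccl.h):
//
//   ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
//   std::vector<uint64_t> ranksUll; // unused by this implementation
//   auto child = NCCLComm::split(
//       parent.get(), /*color_id=*/0, /*rank=*/0, config, ranksUll);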

void NCCLComm::finalize() {
  LockType lock(mutex_);
  if (aborted_) {
    LOG(INFO) << "Rank " << rank_
              << ": NCCL communicator already invalidated. Skipping finalize.";
    return;
  }
  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
  auto comm = getNcclComm();
  C10D_NCCL_CHECK_NONBLOCKING(ncclCommFinalize(comm), std::nullopt);
}

void NCCLComm::destroy() {
  LockType lock(mutex_);
  if (aborted_) {
    LOG(INFO) << "Rank " << rank_
              << ": NCCL communicator already invalidated. Skipping destroy.";
    return;
  }
  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
  auto comm = getNcclComm();
  C10D_NCCL_CHECK(ncclCommDestroy(comm), std::nullopt);
  // Poison future getNcclComm() calls.
  aborted_ = true;
}
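
// Typical teardown order (illustrative): finalize() flushes outstanding NCCL
// operations without blocking, and destroy() then frees the communicator and
// poisons the handle:
//
//   comm->finalize();
//   comm->destroy(); // later getNcclComm() calls throw DistBackendError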

std::string getNcclVersion() {
  static c10::once_flag ncclGetVersionFlag;
  static std::string versionString;
  c10::call_once(ncclGetVersionFlag, []() {
    int version = 0;
    ncclResult_t status = ncclGetVersion(&version);
    // We can't compute the version if the call did not return successfully or
    // if the version code is < 100 (corresponding to 0.1.0).
    if (status != ncclSuccess || version < 100) {
      versionString = "Unknown NCCL version";
    } else {
      // NCCL changed its version coding starting with 2.9.
      const int majorBase = version < 2900 ? 1000 : 10000;
      const int minorBase = 100;
      auto ncclMajor = version / majorBase;
      auto ncclMinor = (version % majorBase) / minorBase;
      auto ncclPatch =
          version % (ncclMajor * majorBase + ncclMinor * minorBase);
      versionString = std::to_string(ncclMajor) + "." +
          std::to_string(ncclMinor) + "." + std::to_string(ncclPatch);
#ifdef NCCL_SUFFIX
      const auto ncclSuffix = std::string(NCCL_SUFFIX);
      if (!ncclSuffix.empty()) {
        versionString += "." + ncclSuffix;
      }
#endif
    }
  });
  return versionString;
}
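
// Worked example of the decoding above (assuming NCCL 2.18.3): ncclGetVersion
// reports 2*10000 + 18*100 + 3 = 21803, so majorBase = 10000,
// ncclMajor = 21803 / 10000 = 2, ncclMinor = (21803 % 10000) / 100 = 18, and
// ncclPatch = 21803 % (2*10000 + 18*100) = 21803 % 21800 = 3, i.e. "2.18.3".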

#ifdef USE_C10D_NCCL
size_t hashTensors(const std::vector<at::Tensor>& tensors) {
  size_t hash = 0;
  for (auto& tensor : tensors) {
    if (tensor.numel() > 0 && tensor.storage()) {
      size_t data_size = tensor.storage().nbytes();
      if (data_size > 0 && tensor.storage().data_ptr()) {
        auto src = static_cast<const char*>(tensor.storage().data_ptr().get());
        std::vector<char> dst(data_size);
        // The copy triggers a device synchronization, so that if a collective
        // launched on the GPU produced this tensor, the collective has
        // finished before we hash its output.
        cudaMemcpy(dst.data(), src, data_size, cudaMemcpyDeviceToHost);
        for (size_t i = 0; i < data_size; ++i) {
          // Update the hash for each byte in the tensor.
          hash = c10::hash_combine(hash, c10::get_hash(dst[i], data_size));
        }
      }
    }
  }
  return hash;
}
#endif
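
// Illustrative use (a sketch): the hash serves as a debugging fingerprint of
// collective inputs/outputs that can be compared across ranks, e.g.:
//
//   LOG(INFO) << "Rank " << rank << ": input hash " << hashTensors(inputs);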

// Default value: 30 minutes.
int nccl_nonblocking_timeout() {
  static int timeout = -2; // -2 means not initialized
  if (timeout == -2) {
    const auto val = c10::utils::get_env("TORCH_NCCL_NONBLOCKING_TIMEOUT");
    if (val.has_value() && !val.value().empty()) {
      timeout = stoi(val.value());
    } else {
      // Default value consistent with kBackendDefaultTimeout.
      timeout = 30 * 60;
    }
  }
  return timeout;
}
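
// For example (illustrative), running with TORCH_NCCL_NONBLOCKING_TIMEOUT=120
// in the environment caps the non-blocking waits above at 120 seconds; when
// the variable is unset or empty, the 30-minute (1800-second) default applies.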

std::string ncclGetErrorWithVersion(ncclResult_t error) {
  return std::string(ncclGetErrorString(error)) + ", NCCL version " +
      getNcclVersion();
}

// Provides additional detail on NCCL error codes, based on where they are
// raised in the NCCL codebase.
std::string getNcclErrorDetailStr(
    ncclResult_t error,
    std::optional<std::string> processGroupFailureReason /* = std::nullopt */
) {
  // Prioritize the failure reason provided by PG NCCL, as it can abort
  // communicators when it encounters collective timeouts, etc.
  if (processGroupFailureReason != std::nullopt) {
    return *processGroupFailureReason;
  }
  std::string interpret;
  std::string err;
#ifdef ENABLE_NCCL_GET_LAST_ERROR
  auto ret = ncclGetLastError(nullptr);
  if (ret) {
    err = "\nLast error:\n" + std::string(ret);
  } else {
    err = "\nLast error: Unknown NCCL Error\n";
  }
#endif
  switch (error) {
    case ncclUnhandledCudaError:
      interpret = "ncclUnhandledCudaError: Call to CUDA function failed.";
      break;
    case ncclSystemError:
      interpret =
          "ncclSystemError: System call (e.g. socket, malloc) or external library call failed or device error. ";
#ifndef NCCL_REMOTE_ERROR
      // Before ncclRemoteError was introduced, an unexpected remote disconnect
      // was categorized as ncclSystemError.
      interpret += "It can also be caused by the unexpected exit of a remote peer.";
#endif
      break;
    case ncclInternalError:
      interpret = "ncclInternalError: Internal check failed.";
      break;
    case ncclInvalidArgument:
      interpret = "ncclInvalidArgument: Invalid value for an argument.";
      break;
    case ncclInvalidUsage:
      interpret =
          "ncclInvalidUsage: This usually reflects invalid usage of the NCCL library.";
      break;
#ifdef NCCL_REMOTE_ERROR
    case ncclRemoteError:
      interpret =
          "ncclRemoteError: A call failed, possibly due to a network error or a remote process exiting prematurely.";
      break;
#endif
    default:
      interpret = "Unknown NCCL error!";
  }
  return interpret + err;
}
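
// Illustrative result (hypothetical): with no process-group-supplied reason
// and ENABLE_NCCL_GET_LAST_ERROR defined, an internal failure yields roughly
//   "ncclInternalError: Internal check failed.\nLast error:\n<NCCL detail>"
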
} // namespace c10d
#endif // USE_C10D_NCCL