#include <ATen/core/functional.h>
#include <torch/csrc/cuda/device_set.h>
#include <torch/csrc/cuda/nccl.h>
#include <ATen/ATen.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <nccl.h>
#include <limits>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <vector>
ncclComm_t* to_nccl_comm(torch::cuda::nccl::ncclComm_t* var) {
return reinterpret_cast<ncclComm_t*>(var);
}
ncclComm_t to_nccl_comm(torch::cuda::nccl::ncclComm_t var) {
return reinterpret_cast<ncclComm_t>(var);
}
ncclUniqueId* to_nccl_unique_id(torch::cuda::nccl::ncclUniqueId* var) {
return reinterpret_cast<ncclUniqueId*>(var);
}
ncclResult_t to_nccl_result(torch::cuda::nccl::ncclResult var) {
switch (var) {
case torch::cuda::nccl::ncclResult::Success:
return ncclResult_t::ncclSuccess;
case torch::cuda::nccl::ncclResult::UnhandledCudaError:
return ncclResult_t::ncclUnhandledCudaError;
case torch::cuda::nccl::ncclResult::SystemError:
return ncclResult_t::ncclSystemError;
case torch::cuda::nccl::ncclResult::InternalError:
return ncclResult_t::ncclInternalError;
case torch::cuda::nccl::ncclResult::InvalidArgument:
return ncclResult_t::ncclInvalidArgument;
case torch::cuda::nccl::ncclResult::InvalidUsage:
return ncclResult_t::ncclInvalidUsage;
case torch::cuda::nccl::ncclResult::NumResults:
return ncclResult_t::ncclNumResults;
default:
throw std::runtime_error("Unconvertible NCCL type");
}
}
torch::cuda::nccl::ncclResult from_nccl_result(ncclResult_t var) {
switch (var) {
case ncclSuccess:
return torch::cuda::nccl::ncclResult::Success;
case ncclUnhandledCudaError:
return torch::cuda::nccl::ncclResult::UnhandledCudaError;
case ncclSystemError:
return torch::cuda::nccl::ncclResult::SystemError;
case ncclInternalError:
return torch::cuda::nccl::ncclResult::InternalError;
case ncclInvalidArgument:
return torch::cuda::nccl::ncclResult::InvalidArgument;
case ncclInvalidUsage:
return torch::cuda::nccl::ncclResult::InvalidUsage;
case ncclNumResults:
return torch::cuda::nccl::ncclResult::NumResults;
default:
throw std::runtime_error("Unconvertible NCCL type");
}
}
ncclDataType_t to_nccl_data_type(c10::ScalarType type) {
switch (type) {
case at::kFloat:
return ncclDataType_t::ncclFloat;
case at::kHalf:
return ncclDataType_t::ncclHalf;
case at::kDouble:
return ncclDataType_t::ncclDouble;
case at::kLong:
return ncclDataType_t::ncclInt64;
case at::kInt:
return ncclDataType_t::ncclInt;
case at::kChar:
return ncclDataType_t::ncclChar;
case at::kByte:
return ncclDataType_t::ncclUint8;
case at::kBool:
return ncclDataType_t::ncclUint8;
#if HAS_NCCL_BF16_DATATYPE
case at::kBFloat16:
return ncclDataType_t::ncclBfloat16;
#endif
default:
TORCH_CHECK(false, "Unconvertible NCCL type ", type);
}
}
ncclDataType_t to_nccl_data_type(const at::Tensor& t) {
if (!t.is_cuda()) {
TORCH_CHECK(
false,
"NCCL only supports CUDA tensors, but got a tensor on ",
t.device());
}
return to_nccl_data_type(t.scalar_type());
}
ncclRedOp_t to_nccl_red_op(int var) {
return static_cast<ncclRedOp_t>(var);
}
namespace torch {
namespace cuda {
namespace nccl {
using namespace at;
namespace detail {
static inline void NCCL_CHECK(ncclResult_t result) {
NCCL_CHECK(from_nccl_result(result));
}
void throw_nccl_error(torch::cuda::nccl::ncclResult status) {
std::ostringstream err;
err << "NCCL Error " << static_cast<int>(status) << ": "
<< ncclGetErrorString(to_nccl_result(status));
throw std::runtime_error(err.str());
}
struct NcclCommList {
std::unique_ptr<ncclComm_t[]> comms;
int ndevices;
NcclCommList(const std::vector<int>& devices)
: comms(new ncclComm_t[devices.size()]), ndevices(devices.size()) {
NCCL_CHECK(ncclCommInitAll(
to_nccl_comm(comms.get()), devices.size(), devices.data()));
}
NcclCommList(NcclCommList&& foo) = default;
~NcclCommList() {
if (comms) {
for (const auto i : c10::irange(ndevices)) {
int dummy_var;
if (cudaGetDevice(&dummy_var) != cudaSuccess) {
/* there are cases when this destructor is called after the
CUDA driver is already unloaded from the process.
In these cases, skip ncclCommDestroy */
return;
}
comm_destroy(comms[i]);
}
}
}
ArrayRef<ncclComm_t> ref() const {
return ArrayRef<ncclComm_t>(comms.get(), ndevices);
}
};
using device_list = std::vector<int>;
// Accesses to this object must be guarded by the CUDA caching allocator's
// free mutex (the one AutoNcclGroup locks via getFreeMutex()).
static std::unordered_map<device_list, NcclCommList, c10::hash<device_list>>
_communicators;
ArrayRef<ncclComm_t> get_communicators(TensorList inputs) {
static auto get_device = [](const at::Tensor& t) -> int {
return t.get_device();
};
device_list devices = fmap(inputs, get_device);
auto it = _communicators.find(devices);
if (it == _communicators.end())
std::tie(it, std::ignore) = _communicators.emplace(devices, devices);
return it->second.ref();
}
static inline void check_tensor(
const at::Tensor& input,
const at::optional<at::Tensor>& output,
int input_multiplier,
int output_multiplier,
int64_t ref_numel,
ScalarType ref_dtype) {
auto check_one = [&](const at::Tensor& tensor) {
if (!tensor.is_cuda() || tensor.is_sparse()) {
throw std::runtime_error(
"input and output elements have to be cuda dense Tensors");
}
if (ref_dtype != tensor.scalar_type()) {
throw std::runtime_error(
"all inputs and outputs must be of the same Tensor dtype");
}
if (!tensor.is_contiguous()) {
throw std::runtime_error("all inputs and outputs have to be contiguous");
}
};
check_one(input);
// all inputs must be same size
if (input.numel() != ref_numel) {
throw std::runtime_error(
"all inputs must have the same number of elements");
}
if (output) {
check_one(*output);
// inputs and outputs must be on same device respectively
if (input.get_device() != output->get_device()) {
throw std::runtime_error("input and output must be on the same device");
}
if (output->numel() * output_multiplier != ref_numel * input_multiplier) {
throw std::runtime_error(
"output must be of size input_size * size_multiplier");
}
}
}
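/*
Illustrative sketch only, not part of this file's API. It restates the size
relation that check_tensor enforces,
output_numel * output_multiplier == input_numel * input_multiplier,
using plain integers so the multiplier pairs passed by the collectives are
easier to follow. The names sizes_compatible and demo_size_checks, and the
device count of 4, are assumptions made for the example.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper mirroring the numel comparison in check_tensor.
bool sizes_compatible(
    std::int64_t input_numel,
    std::int64_t output_numel,
    int input_multiplier,
    int output_multiplier) {
  return output_numel * output_multiplier == input_numel * input_multiplier;
}

// Exercises the multiplier pairs that the collectives pass to check_inputs.
void demo_size_checks() {
  const int len = 4; // assumed number of participating devices
  assert(sizes_compatible(8, 8, 1, 1)); // broadcast, reduce, all_reduce
  assert(sizes_compatible(8, 32, len, 1)); // all_gather: output is len x input
  assert(sizes_compatible(8, 2, 1, len)); // reduce_scatter: output is input / len
  assert(!sizes_compatible(8, 8, len, 1)); // all_gather output too small
}
```

Calling demo_size_checks() from a test is enough to exercise all four cases.
*/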
void check_inputs(
TensorList inputs,
TensorList outputs,
int input_multiplier,
int output_multiplier) {
// len(inputs) == len(outputs)
size_t len = inputs.size();
if (len == 0) {
throw std::runtime_error("input sequence can't be empty");
}
if (len != outputs.size()) {
std::stringstream err;
err << "inputs and outputs sequences have to be of the same length, but got input of length "
<< len << " and output of length " << outputs.size();
throw std::runtime_error(err.str());
}
device_set devices;
int64_t numel = inputs[0].numel();
auto dtype = inputs[0].scalar_type();
for (const auto i : c10::irange(len)) {
auto input = inputs[i];
auto output = outputs[i];
check_tensor(
input, output, input_multiplier, output_multiplier, numel, dtype);
auto input_device = input.get_device();
// inputs must be on unique devices
if (devices.test(input_device)) {
throw std::runtime_error("inputs must be on unique devices");
}
devices.set(input_device);
}
}
void check_inputs(
TensorList inputs,
const at::Tensor& output,
int root,
int input_multiplier,
int output_multiplier) {
size_t len = inputs.size();
if (len == 0) {
throw std::runtime_error("input sequence can't be empty");
}
device_set devices;
int64_t numel = inputs[0].numel();
auto dtype = inputs[0].scalar_type();
for (const auto i : c10::irange(len)) {
auto input = inputs[i];
check_tensor(
input,
i == static_cast<size_t>(root) ? at::optional<at::Tensor>{output}
: at::nullopt,
input_multiplier,
output_multiplier,
numel,
dtype);
auto input_device = input.get_device();
// inputs must be on unique devices
if (devices.test(input_device)) {
throw std::runtime_error("inputs must be on unique devices");
}
devices.set(input_device);
}
}
} // namespace detail
AutoNcclGroup::AutoNcclGroup() {
(c10::cuda::CUDACachingAllocator::getFreeMutex())->lock();
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
detail::NCCL_CHECK(ncclGroupStart());
#endif
}
AutoNcclGroup::~AutoNcclGroup() noexcept(false) {
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
detail::NCCL_CHECK(ncclGroupEnd());
#endif
(c10::cuda::CUDACachingAllocator::getFreeMutex())->unlock();
}
bool is_available(TensorList tensors) {
#ifdef USE_NCCL
device_set devices;
for (auto& tensor : tensors) {
if (!tensor.is_cuda() || tensor.is_sparse())
return false;
if (!tensor.is_contiguous())
return false;
auto device = tensor.get_device();
if (devices[device])
return false;
devices[device] = true;
}
return true;
#else
return false;
#endif
}
std::uint64_t version() {
#if defined(NCCL_MAJOR)
constexpr std::uint64_t ver = (((uint64_t)NCCL_MAJOR) << 32) |
(((uint64_t)NCCL_MINOR) << 16) | ((uint64_t)NCCL_PATCH);
return ver;
#elif defined(USE_NCCL)
// return major version "1"
return ((uint64_t)1) << 32;
#else
return 0;
#endif
}
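/*
Illustrative sketch only. version() packs the NCCL version into a single
64-bit value: major in bits 32 and up, minor in bits 16 to 31, patch in the
low 16 bits. The snippet below shows that layout and one way a caller might
decode it. pack_version, unpack_version and NcclVersionParts are hypothetical
names that do not exist in this file, and the 16-bit masks assume minor and
patch each fit in 16 bits.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical struct holding the decoded components.
struct NcclVersionParts {
  std::uint32_t major_num;
  std::uint32_t minor_num;
  std::uint32_t patch_num;
};

// Packs the components with the same shifts that version() uses.
constexpr std::uint64_t pack_version(
    std::uint64_t major_num,
    std::uint64_t minor_num,
    std::uint64_t patch_num) {
  return (major_num << 32) | (minor_num << 16) | patch_num;
}

// Reverses the packing; assumes minor and patch each fit in 16 bits.
constexpr NcclVersionParts unpack_version(std::uint64_t packed) {
  return NcclVersionParts{
      static_cast<std::uint32_t>(packed >> 32),
      static_cast<std::uint32_t>((packed >> 16) & 0xFFFF),
      static_cast<std::uint32_t>(packed & 0xFFFF)};
}

// Round-trips an example version number (2.18.3, chosen arbitrarily).
void demo_version_roundtrip() {
  const std::uint64_t packed = pack_version(2, 18, 3);
  const NcclVersionParts parts = unpack_version(packed);
  assert(parts.major_num == 2);
  assert(parts.minor_num == 18);
  assert(parts.patch_num == 3);
}
```

With this layout, comparing two packed values with < or >= orders them by
major, then minor, then patch.
*/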
void get_unique_id(ncclUniqueId& id) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id)));
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
ncclComm_t comm;
ncclUniqueId id = comm_id;
NCCL_CHECK(ncclCommInitRank(
to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank));
return comm;
#else
return nullptr;
#endif
}
void comm_destroy(ncclComm_t comm) {
/*
* TODO(T30279827) Temporarily disable calling ncclCommDestroy.
* According to Nvidia, calling ncclCommDestroy while the program is exiting
* is undefined and leads to a segfault in NCCL 2 (whether it is called
* before or after the CUDA runtime destructor).
* The early return below skips the destroy call to avoid that segfault.
* Following up with Nvidia for a long-term solution.
*/
return;
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
NCCL_CHECK(ncclCommDestroy(to_nccl_comm(comm)));
#endif
}
namespace {
// NCCL changed the numerical type used for count between NCCL1 and NCCL2.
// GetSecondArgType extracts the (decayed) type of the second parameter of a
// function type T. Applying it to decltype(ncclBcast) yields the count type
// statically, without hard-coding it per NCCL version.
template <typename T>
struct GetSecondArgType;
template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
typedef typename std::decay<Arg1>::type type;
};
constexpr auto count_max =
std::numeric_limits<GetSecondArgType<decltype(ncclBcast)>::type>::max();
} // namespace
size_t get_max_count() {
return count_max;
}
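/*
Illustrative sketch only. It reproduces the GetSecondArgType trait in
isolation and applies it to two stand-in declarations whose second parameter
differs in type, the way the count parameter of ncclBcast differs between
NCCL1 and NCCL2. bcast_v1 and bcast_v2 are hypothetical; they are declared
but never defined or called, because decltype only needs the signature.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <type_traits>

template <typename T>
struct GetSecondArgType;

// Matches any function type with at least two parameters.
template <typename R, typename Arg0, typename Arg1, typename... Args>
struct GetSecondArgType<R(Arg0, Arg1, Args...)> {
  using type = typename std::decay<Arg1>::type;
};

// Hypothetical stand-ins for two library versions of the same call.
int bcast_v1(void* buff, int count, int root);
int bcast_v2(void* buff, std::size_t count, int root);

using CountV1 = GetSecondArgType<decltype(bcast_v1)>::type;
using CountV2 = GetSecondArgType<decltype(bcast_v2)>::type;

static_assert(std::is_same<CountV1, int>::value, "v1 counts are int");
static_assert(std::is_same<CountV2, std::size_t>::value, "v2 counts are size_t");

// The largest element count each stand-in signature could accept.
constexpr auto count_max_v1 = std::numeric_limits<CountV1>::max();
constexpr auto count_max_v2 = std::numeric_limits<CountV2>::max();
```

Because the trait is resolved at compile time, a bound derived this way
follows whichever header is being compiled against.
*/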
void broadcast(
TensorList tensors,
const stream_list& streams,
const comm_list& user_comms) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
check_inputs(tensors, tensors, 1, 1);
auto data_type = to_nccl_data_type(tensors[0]);
int64_t numel = tensors[0].numel();
const auto comms = user_comms.empty() ? get_communicators(tensors)
: ArrayRef<ncclComm_t>(user_comms);
AutoNcclGroup nccl_group_guard;
at::cuda::OptionalCUDAGuard device_guard;
for (size_t i = 0, num_tensors = tensors.size(); i < num_tensors; i++) {
int device = tensors[i].get_device();
device_guard.set_index(device);
// Default to the current stream
const auto stream = (streams.empty() || !streams[i])
? at::cuda::getCurrentCUDAStream(device).stream()
: streams[i]->stream();
TORCH_CHECK(
static_cast<uint64_t>(numel) <= static_cast<uint64_t>(count_max),
"Broadcast tensor has ",
numel,
" elements, which exceeds the "
"maximum NCCL supports (",
count_max,
")");
ncclComm_t comm = comms[i];
NCCL_CHECK(ncclBcast(
tensors[i].data_ptr(),
numel,
data_type,
0,
to_nccl_comm(comm),
stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void reduce(
const std::vector<at::Tensor>& inputs,
at::Tensor& output,
int32_t root,
int32_t op,
const stream_list& streams,
const comm_list& user_comms) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
TORCH_CHECK(
root >= 0 && static_cast<size_t>(root) < inputs.size(), "invalid root");
check_inputs(inputs, output, root, 1, 1);
const auto len = inputs.size();
auto data_type = to_nccl_data_type(inputs[0]);
const auto count = inputs[0].numel();
auto comms_ref = user_comms.empty() ? get_communicators(inputs)
: ArrayRef<ncclComm_t>(user_comms);
AutoNcclGroup nccl_group_guard;
at::cuda::OptionalCUDAGuard device_guard;
for (const auto i : c10::irange(len)) {
int device = inputs[i].device().index();
device_guard.set_index(device);
// Default to the current stream
const auto stream = (streams.empty() || !streams[i])
? at::cuda::getCurrentCUDAStream(device).stream()
: streams[i]->stream();
ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclReduce(
inputs[i].data_ptr(),
static_cast<size_t>(root) == i ? output.data_ptr() : nullptr,
count,
data_type,
to_nccl_red_op(op),
root,
to_nccl_comm(comm),
stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void reduce(
std::vector<at::Tensor>& inputs,
int32_t root,
int32_t op,
const stream_list& streams,
const comm_list& user_comms) {
reduce(inputs, /*output=*/inputs[root], root, op, streams, user_comms);
}
void all_reduce(
const std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
int32_t op,
const stream_list& streams,
const comm_list& user_comms) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
check_inputs(inputs, outputs, 1, 1);
const auto len = inputs.size();
auto data_type = to_nccl_data_type(inputs[0]);
const auto count = inputs[0].numel();
auto comms_ref = user_comms.empty() ? get_communicators(inputs)
: ArrayRef<ncclComm_t>(user_comms);
AutoNcclGroup nccl_group_guard;
at::cuda::OptionalCUDAGuard device_guard;
for (const auto i : c10::irange(len)) {
int device = inputs[i].device().index();
device_guard.set_index(device);
// Default to the current stream
const auto stream = (streams.empty() || !streams[i])
? at::cuda::getCurrentCUDAStream(device).stream()
: streams[i]->stream();
ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclAllReduce(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
to_nccl_comm(comm),
stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void reduce_scatter(
const std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
int32_t op,
const stream_list& streams,
const comm_list& user_comms) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
const auto len = inputs.size();
check_inputs(inputs, outputs, 1, len);
auto data_type = to_nccl_data_type(inputs[0]);
const auto count = inputs[0].numel() / len;
auto comms_ref = user_comms.empty() ? get_communicators(inputs)
: ArrayRef<ncclComm_t>(user_comms);
AutoNcclGroup nccl_group_guard;
at::cuda::OptionalCUDAGuard device_guard;
for (const auto i : c10::irange(len)) {
int device = inputs[i].device().index();
device_guard.set_index(device);
// Default to the current stream
const auto stream = (streams.empty() || !streams[i])
? at::cuda::getCurrentCUDAStream(device).stream()
: streams[i]->stream();
ncclComm_t comm = comms_ref[i];
NCCL_CHECK(ncclReduceScatter(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_red_op(op),
to_nccl_comm(comm),
stream));
}
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void all_gather(
const std::vector<at::Tensor>& inputs,
std::vector<at::Tensor>& outputs,
const stream_list& streams,
const comm_list& user_comms) {
#ifdef USE_NCCL
using namespace torch::cuda::nccl::detail;
const auto len = inputs.size();
check_inputs(inputs, outputs, len, 1);
auto data_type = to_nccl_data_type(inputs[0]);
const auto count = inputs[0].numel();
auto comms_ref = user_comms.empty() ? get_communicators(inputs)
: ArrayRef<ncclComm_t>(user_comms);
AutoNcclGroup nccl_group_guard;
at::cuda::OptionalCUDAGuard device_guard;
for (const auto i : c10::irange(len)) {
int device = inputs[i].device().index();
device_guard.set_index(device);
// Default to the current stream
const auto stream = (streams.empty() || !streams[i])
? at::cuda::getCurrentCUDAStream(device).stream()
: streams[i]->stream();
ncclComm_t comm = comms_ref[i];
#if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2)
NCCL_CHECK(ncclAllGather(
inputs[i].data_ptr(),
outputs[i].data_ptr(),
count,
data_type,
to_nccl_comm(comm),
stream));
#else
NCCL_CHECK(ncclAllGather(
inputs[i].data_ptr(),
count,
data_type,
outputs[i].data_ptr(),
to_nccl_comm(comm),
stream));
#endif
}
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void all2all_single_equal_split(
at::Tensor& input,
at::Tensor& output,
int size,
ncclComm_t _comm,
at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
(NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;
int numranks;
auto type = to_nccl_data_type(input);
size_t count = input.numel() / size;
size_t rankdiff = input.nbytes() / size;
const auto* sendbuff = reinterpret_cast<char*>(input.data_ptr());
auto* recvbuff = reinterpret_cast<char*>(output.data_ptr());
auto comm = to_nccl_comm(_comm);
#if defined(USE_ROCM) && ROCM_VERSION >= 50000
NCCL_CHECK(ncclAllToAll(sendbuff, recvbuff, count, type, comm, stream));
#else
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclGroupStart());
for (const auto r : c10::irange(numranks)) {
// NCCL uses 0 byte message for synchronization
// Avoid send/recv when message size is zero
if (count != 0) {
NCCL_CHECK(
ncclSend(sendbuff + r * rankdiff, count, type, r, comm, stream));
NCCL_CHECK(
ncclRecv(recvbuff + r * rankdiff, count, type, r, comm, stream));
}
}
NCCL_CHECK(ncclGroupEnd());
#endif
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void all2all_single_unequal_split(
void* sendbuff,
const size_t* sendcounts,
const size_t* senddispls,
void* recvbuff,
const size_t* recvcounts,
const size_t* recvdispls,
size_t size,
c10::ScalarType _type,
ncclComm_t _comm,
at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
(NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;
auto type = to_nccl_data_type(_type);
auto comm = to_nccl_comm(_comm);
int numranks;
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclGroupStart());
for (const auto r : c10::irange(numranks)) {
// NCCL uses 0 byte message for synchronization
// Avoid send/recv when message size is zero
if (sendcounts[r] != 0) {
NCCL_CHECK(ncclSend(
((char*)sendbuff) + senddispls[r] * size,
sendcounts[r],
type,
r,
comm,
stream));
}
if (recvcounts[r] != 0) {
NCCL_CHECK(ncclRecv(
((char*)recvbuff) + recvdispls[r] * size,
recvcounts[r],
type,
r,
comm,
stream));
}
}
NCCL_CHECK(ncclGroupEnd());
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void all2all(
std::vector<at::Tensor>& outputTensors,
std::vector<at::Tensor>& inputTensors,
ncclComm_t _comm,
at::cuda::CUDAStream& stream) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
(NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;
auto comm = to_nccl_comm(_comm);
NCCL_CHECK(ncclGroupStart());
for (const auto r : c10::irange(outputTensors.size())) {
at::Tensor& input = inputTensors[r];
at::Tensor& output = outputTensors[r];
if (input.numel() != 0) {
NCCL_CHECK(ncclSend(
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
r,
comm,
stream.stream()));
}
if (output.numel() != 0) {
NCCL_CHECK(ncclRecv(
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
r,
comm,
stream.stream()));
}
}
NCCL_CHECK(ncclGroupEnd());
#else
AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void send(
const at::Tensor& input,
ncclComm_t comm,
at::cuda::CUDAStream stream,
int dst) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 7)
using namespace torch::cuda::nccl::detail;
NCCL_CHECK(ncclSend(
input.data_ptr(),
input.numel(),
to_nccl_data_type(input),
dst,
to_nccl_comm(comm),
stream.stream()));
#else
AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void recv(
at::Tensor& output,
ncclComm_t comm,
at::cuda::CUDAStream stream,
int src) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \
(NCCL_MINOR >= 7)
using namespace torch::cuda::nccl::detail;
NCCL_CHECK(ncclRecv(
output.data_ptr(),
output.numel(),
to_nccl_data_type(output),
src,
to_nccl_comm(comm),
stream.stream()));
#else
AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void gather(
const at::Tensor& inputs,
std::vector<at::Tensor>& outputs,
ncclComm_t _comm,
at::cuda::CUDAStream& stream,
int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
(NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;
auto comm = to_nccl_comm(_comm);
int numranks, cur_rank;
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
size_t count = inputs.numel();
auto type = to_nccl_data_type(inputs);
const auto* sendbuff = reinterpret_cast<char*>(inputs.data_ptr());
NCCL_CHECK(ncclGroupStart());
if (cur_rank == root) {
for (const auto r : c10::irange(numranks)) {
if (r != root) {
auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
} else {
// on its own rank, simply copy from the input
outputs[r].copy_(inputs);
}
}
} else {
NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
}
NCCL_CHECK(ncclGroupEnd());
#else
AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
void scatter(
const std::vector<at::Tensor>& inputs,
at::Tensor& outputs,
ncclComm_t _comm,
at::cuda::CUDAStream& stream,
int32_t root) {
#ifdef USE_NCCL
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && \
(NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
using namespace torch::cuda::nccl::detail;
auto comm = to_nccl_comm(_comm);
int numranks, cur_rank;
NCCL_CHECK(ncclCommCount(comm, &numranks));
NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
NCCL_CHECK(ncclGroupStart());
if (cur_rank == root) {
for (const auto r : c10::irange(numranks)) {
if (r != root) {
size_t send_count = inputs[r].numel();
auto send_type = to_nccl_data_type(inputs[r]);
const auto* sendbuff = reinterpret_cast<char*>(inputs[r].data_ptr());
NCCL_CHECK(ncclSend(sendbuff, send_count, send_type, r, comm, stream));
} else {
// on its own rank, simply copy it to the output
outputs.copy_(inputs[r]);
}
}
} else {
size_t recv_count = outputs.numel();
auto recv_type = to_nccl_data_type(outputs);
auto* recvbuff = reinterpret_cast<char*>(outputs.data_ptr());
NCCL_CHECK(ncclRecv(recvbuff, recv_count, recv_type, root, comm, stream));
}
NCCL_CHECK(ncclGroupEnd());
#else
AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0");
#endif
#else
AT_ERROR("PyTorch built without NCCL support");
#endif
}
} // namespace nccl
} // namespace cuda
} // namespace torch