#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>

#include <algorithm>
#include <string>

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>

namespace c10d {
namespace intra_node_comm {
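
// Message-size cutoffs used by selectAllReduceAlgo() below to pick between
// the one-shot and two-shot all-reduce kernels.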
static constexpr size_t kOneShotThreshBytes = 256 * 1024;
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;
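
// Validates that `input` is eligible for the intra-node all-reduce kernels:
// supported dtype, non-overlapping and dense layout, CUDA resident, and on
// the device this IntraNodeComm instance was created for.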
static void checkInput(const at::Tensor& input, int deviceIdx) {
  TORCH_CHECK(
      input.dtype() == at::kBFloat16 || input.dtype() == at::kFloat,
      "IntraNodeComm only supports float and bf16 for now");
  TORCH_CHECK(input.is_non_overlapping_and_dense());
  TORCH_CHECK(input.device().is_cuda());
  TORCH_CHECK(
      input.get_device() == deviceIdx,
      "IntraNodeComm: expect input to be on device ",
      deviceIdx,
      ", got device ",
      input.get_device());
}
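
// Compile-time capability check: intra-node comm is disabled on ROCm and on
// CUDA architectures older than sm_80.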
bool isIntraNodeCommSupported() {
#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800))
  return false;
#else
  return true;
#endif
}
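
// One-shot all-reduce: the input is staged in this rank's symmetric memory
// buffer, then symm_mem::one_shot_all_reduce_out reduces every peer's buffer
// directly into `input`, which doubles as the output tensor.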
at::Tensor IntraNodeComm::oneShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("symm_mem::one_shot_all_reduce_out", "")
                .typed<at::Tensor(
                    const at::Tensor&, std::string, std::string, at::Tensor)>();

  auto symmMemTensor = at::from_blob(
      symmetricMemoryPtr_,
      input.sizes(),
      at::TensorOptions().dtype(input.dtype()).device(input.device()));

  symmMemTensor.copy_(input);
  op.call(symmMemTensor, "sum", "", input);
  return input;
}
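
// Two-shot all-reduce: the input is staged in symmetric memory and reduced
// in place by symm_mem::two_shot_all_reduce_ (conceptually a reduce-scatter
// followed by an all-gather), then copied back into `input`.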
at::Tensor IntraNodeComm::twoShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("symm_mem::two_shot_all_reduce_", "")
                .typed<at::Tensor(at::Tensor, std::string, std::string)>();

  auto symmMemTensor = at::from_blob(
      symmetricMemoryPtr_,
      input.sizes(),
      at::TensorOptions().dtype(input.dtype()).device(input.device()));

  symmMemTensor.copy_(input);
  op.call(symmMemTensor, "sum", "");
  input.copy_(symmMemTensor);
  return input;
}
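
// Picks an all-reduce algorithm based on dtype, alignment, topology, buffer
// capacity, and message size: one-shot for small inputs, two-shot for
// medium-sized inputs, and NONE (caller should fall back to another
// implementation, e.g. NCCL) otherwise.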
AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
  // Only float and bf16 are supported for now.
  if (input.dtype() != at::kBFloat16 && input.dtype() != at::kFloat) {
    return AllReduceAlgo::NONE;
  }
  const auto inputSize =
      static_cast<size_t>(input.numel() * input.element_size());
  const size_t ptrAlignment = get_alignment(
      static_cast<size_t>(input.storage_offset() * input.element_size()));
  const size_t sizeAlignment = get_alignment(inputSize);
  const size_t alignment = std::min(ptrAlignment, sizeAlignment);

  if (topology_ == Topology::FULLY_CONNECTED) {
    // Both symm_mem::one_shot_all_reduce and symm_mem::two_shot_all_reduce_
    // currently require the input to be at least 4-byte aligned.
    if (alignment >= 4 && inputSize <= kOneShotThreshBytes &&
        inputSize <= bufferSize_) {
      return AllReduceAlgo::ONE_SHOT;
    }
    if (alignment >= 4 && inputSize <= kTwoShotThreshBytes &&
        inputSize <= bufferSize_) {
      return AllReduceAlgo::TWO_SHOT;
    }
  }
  return AllReduceAlgo::NONE;
}

static int64_t usageCounter = 0;
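
// Dispatches to the algorithm previously chosen by selectAllReduceAlgo().
// The reduction runs on the current CUDA stream and `input` is reduced in
// place. A typical call pattern might look like the sketch below (assuming
// an initialized IntraNodeComm instance `comm`; the exact setup lives in the
// owning process group):
//
//   auto algo = comm->selectAllReduceAlgo(tensor);
//   if (algo != AllReduceAlgo::NONE) {
//     comm->allReduce(tensor, algo);
//   } else {
//     // fall back to another all-reduce implementation
//   }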
at::Tensor IntraNodeComm::allReduce(
    const at::Tensor& input,
    AllReduceAlgo algo) {
  // Report usage for testing purposes.
  // We don't care about overflowing.
  ++usageCounter;
  auto stream = at::cuda::getCurrentCUDAStream();
  switch (algo) {
    case AllReduceAlgo::ONE_SHOT:
      return oneShotAllReduce(input, stream);
    case AllReduceAlgo::TWO_SHOT:
      return twoShotAllReduce(input, stream);
    default:
      C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
  }
}
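
// Test-only introspection: returns how many times IntraNodeComm::allReduce
// has been invoked in this process.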
int64_t getIntraNodeCommUsageCounter() {
  return usageCounter;
}

} // namespace intra_node_comm
} // namespace c10d