File: intra_node_comm.cu

package info (click to toggle)
pytorch-cuda 2.6.0+dfsg-7
  • links: PTS, VCS
  • area: contrib
  • in suites: forky, sid, trixie
  • size: 161,620 kB
  • sloc: python: 1,278,832; cpp: 900,322; ansic: 82,710; asm: 7,754; java: 3,363; sh: 2,811; javascript: 2,443; makefile: 597; ruby: 195; xml: 84; objc: 68
file content (124 lines) | stat: -rw-r--r-- 3,900 bytes parent folder | download | duplicates (3)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <torch/csrc/distributed/c10d/intra_node_comm.hpp>

#include <torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h>

namespace c10d {
namespace intra_node_comm {

// Payload-size thresholds (in bytes) used by selectAllReduceAlgo to pick an
// all-reduce strategy: inputs up to 256 KiB are eligible for the one-shot
// kernel; inputs up to 10 MiB are eligible for the two-shot kernel. Larger
// inputs fall back to AllReduceAlgo::NONE.
static constexpr size_t kOneShotThreshBytes = 256 * 1024;
static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024;

// Validates that `input` is eligible for the intra-node all-reduce path:
// supported dtype (float32 or bfloat16), non-overlapping-and-dense layout,
// and resident on CUDA device `deviceIdx`. Throws via TORCH_CHECK otherwise.
//
// NOTE: this guard is shared by both oneShotAllReduce and twoShotAllReduce,
// so the dtype message deliberately names IntraNodeComm rather than a
// specific algorithm (the previous message mentioned only oneShotAllReduce,
// which was misleading when the check fired from the two-shot path).
static void checkInput(const at::Tensor& input, int deviceIdx) {
  TORCH_CHECK(
      input.dtype() == at::kBFloat16 || input.dtype() == at::kFloat,
      "IntraNodeComm only supports float and bf16 for now");
  TORCH_CHECK(input.is_non_overlapping_and_dense());
  TORCH_CHECK(input.device().is_cuda());
  TORCH_CHECK(
      input.get_device() == deviceIdx,
      "IntraNodeComm: expect input to be on device ",
      deviceIdx,
      ", got device ",
      input.get_device());
}

// Reports whether this build can use the intra-node comm fast path.
// Supported on CUDA builds targeting SM80+ (and in host-side compilation,
// where __CUDA_ARCH__ is not defined); unsupported on ROCm and on CUDA
// architectures older than SM80.
bool isIntraNodeCommSupported() {
#if !defined(USE_ROCM) && (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800))
  return true;
#else
  return false;
#endif
}

// One-shot all-reduce: stages `input` in the symmetric-memory buffer, then
// invokes the symm_mem out-variant op, which writes the reduced result
// directly into `input`. Returns `input` (reduced in place from the caller's
// perspective). `input` must pass checkInput for this rank's device.
at::Tensor IntraNodeComm::oneShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  // Resolve the out-variant one-shot all-reduce registered by the symmetric
  // memory backend: (input, reduce_op, group_name, out) -> Tensor.
  auto allReduceOp =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("symm_mem::one_shot_all_reduce_out", "")
          .typed<at::Tensor(
              const at::Tensor&, std::string, std::string, at::Tensor)>();

  // View the pre-registered symmetric-memory region as a tensor matching the
  // input's shape/dtype/device so it can serve as the op's staging input.
  const auto stagingOptions =
      at::TensorOptions().dtype(input.dtype()).device(input.device());
  auto staging = at::from_blob(symmetricMemoryPtr_, input.sizes(), stagingOptions);

  // Stage, then reduce directly into `input` (the last argument is the out
  // tensor), so no explicit copy-back is needed.
  staging.copy_(input);
  allReduceOp.call(staging, "sum", "", input);
  return input;
}

// Two-shot all-reduce: stages `input` in the symmetric-memory buffer, runs
// the in-place symm_mem op on the staging tensor, then copies the reduced
// result back into `input`. Returns `input`. `input` must pass checkInput
// for this rank's device.
at::Tensor IntraNodeComm::twoShotAllReduce(
    const at::Tensor& input,
    at::cuda::CUDAStream& stream) {
  checkInput(input, deviceIdx_);

  // Resolve the in-place two-shot all-reduce registered by the symmetric
  // memory backend: (tensor, reduce_op, group_name) -> Tensor.
  auto allReduceOp =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("symm_mem::two_shot_all_reduce_", "")
          .typed<at::Tensor(at::Tensor, std::string, std::string)>();

  // View the pre-registered symmetric-memory region as a tensor matching the
  // input's shape/dtype/device to act as the reduction buffer.
  const auto stagingOptions =
      at::TensorOptions().dtype(input.dtype()).device(input.device());
  auto staging = at::from_blob(symmetricMemoryPtr_, input.sizes(), stagingOptions);

  // Stage, reduce in place within symmetric memory, then copy the result
  // back out (unlike the one-shot out-variant, this op has no out tensor).
  staging.copy_(input);
  allReduceOp.call(staging, "sum", "");
  input.copy_(staging);
  return input;
}

// Chooses an intra-node all-reduce algorithm for `input`, or NONE when the
// fast path does not apply. Eligibility requires: float32/bfloat16 dtype,
// FULLY_CONNECTED topology, at least 4-byte alignment of both the payload
// offset and size, and a payload that fits in the symmetric-memory buffer.
// Payloads up to kOneShotThreshBytes use ONE_SHOT; up to kTwoShotThreshBytes
// use TWO_SHOT; anything larger returns NONE.
AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) {
  // Only float and bf16 are supported for now.
  if (input.dtype() != at::kBFloat16 && input.dtype() != at::kFloat) {
    return AllReduceAlgo::NONE;
  }
  if (topology_ != Topology::FULLY_CONNECTED) {
    return AllReduceAlgo::NONE;
  }

  const auto numBytes =
      static_cast<size_t>(input.numel() * input.element_size());
  const auto offsetBytes =
      static_cast<size_t>(input.storage_offset() * input.element_size());
  // Effective alignment is limited by the weaker of the data-offset and
  // payload-size alignments.
  const size_t alignment =
      std::min(get_alignment(offsetBytes), get_alignment(numBytes));

  // Both symm_mem::one_shot_all_reduce and symm_mem::two_shot_all_reduce_
  // currently require the input to be at least 4-bytes aligned, and the
  // payload must fit in the pre-allocated symmetric-memory buffer.
  if (alignment < 4 || numBytes > bufferSize_) {
    return AllReduceAlgo::NONE;
  }
  if (numBytes <= kOneShotThreshBytes) {
    return AllReduceAlgo::ONE_SHOT;
  }
  if (numBytes <= kTwoShotThreshBytes) {
    return AllReduceAlgo::TWO_SHOT;
  }
  return AllReduceAlgo::NONE;
}

// Count of IntraNodeComm::allReduce invocations in this process; exposed via
// getIntraNodeCommUsageCounter() so tests can verify the intra-node path was
// taken. Plain (non-atomic) int64_t — presumably incremented from a single
// thread; NOTE(review): confirm callers never race on this.
static int64_t usageCounter = 0;

// Dispatches `input` to the all-reduce implementation selected by `algo`
// (typically the result of selectAllReduceAlgo) on the current CUDA stream.
// Throws a ValueError for any algo without a fast-path implementation
// (including AllReduceAlgo::NONE).
at::Tensor IntraNodeComm::allReduce(
    const at::Tensor& input,
    AllReduceAlgo algo) {
  // Report usage for testing purposes; overflow is acceptable.
  usageCounter++;
  auto currentStream = at::cuda::getCurrentCUDAStream();
  if (algo == AllReduceAlgo::ONE_SHOT) {
    return oneShotAllReduce(input, currentStream);
  }
  if (algo == AllReduceAlgo::TWO_SHOT) {
    return twoShotAllReduce(input, currentStream);
  }
  C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo");
}

// Returns the number of IntraNodeComm::allReduce calls made so far in this
// process (the value of the file-local usageCounter). Intended for tests.
int64_t getIntraNodeCommUsageCounter() {
  return usageCounter;
}

} // namespace intra_node_comm
} // namespace c10d