File: test_allgather.cu

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (117 lines) | stat: -rw-r--r-- 4,506 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
#include <thrust/device_vector.h>  // for device_vector
#include <thrust/equal.h>          // for equal
#include <xgboost/span.h>          // for Span

#include <cstddef>  // for size_t
#include <cstdint>  // for int32_t, int64_t
#include <numeric>  // for accumulate
#include <vector>   // for vector

#include "../../../src/collective/allgather.h"     // for RingAllgather
#include "../../../src/common/device_helpers.cuh"  // for ToSpan,  device_vector
#include "../../../src/common/type.h"              // for EraseType
#include "test_worker.cuh"                         // for NCCLWorkerForTest
#include "test_worker.h"                           // for TestDistributed, WorkerForTest

namespace xgboost::collective {
namespace {
/**
 * @brief Per-rank test worker that exercises the NCCL variable-length allgather.
 *
 * Uses the `comm_`, `nccl_comm_` and `nccl_coll_` members inherited through
 * NCCLWorkerForTest; constructors are inherited unchanged.
 */
class Worker : public NCCLWorkerForTest {
 public:
  using NCCLWorkerForTest::NCCLWorkerForTest;

  /**
   * @brief Run AllgatherV with the given algorithm and validate the gathered buffer.
   *
   * Two scenarios are checked:
   *  - basic: every rank contributes one int32 equal to its rank;
   *  - V: rank r contributes `r * 256 * 256` int32 values, all equal to r, so rank 0
   *    contributes an empty segment.
   *
   * Failures are reported through GoogleTest fatal assertions, which return from this
   * function immediately.
   *
   * @param algo Algorithm forwarded to AllgatherV.
   */
  void TestV(AllgatherVAlgo algo) {
    {
      // basic test
      std::size_t n = 1;
      // create data
      dh::device_vector<std::int32_t> data(n, comm_.Rank());
      auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
      // get size: each rank fills in its own slot, the host allgather fills in the rest.
      std::vector<std::int64_t> sizes(comm_.World(), -1);
      sizes[comm_.Rank()] = s_data.size_bytes();
      auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
      SafeColl(rc);
      // create result, pre-filled with -1 so that untouched slots fail the check below.
      dh::device_vector<std::int32_t> result(comm_.World(), -1);
      auto s_result = common::EraseType(dh::ToSpan(result));

      std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
      rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
                                  common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
      SafeColl(rc);

      for (std::int32_t i = 0; i < comm_.World(); ++i) {
        ASSERT_EQ(result[i], i);
      }
    }
    {
      // V test
      std::size_t n = 256 * 256;
      // create data
      dh::device_vector<std::int32_t> data(n * nccl_comm_->Rank(), nccl_comm_->Rank());
      auto s_data = common::EraseType(common::Span{data.data().get(), data.size()});
      // get size
      std::vector<std::int64_t> sizes(nccl_comm_->World(), 0);
      sizes[comm_.Rank()] = dh::ToSpan(data).size_bytes();
      auto rc = RingAllgather(comm_, common::Span{sizes.data(), sizes.size()});
      SafeColl(rc);
      // Seed the sum with a 64-bit zero: an `int` seed would make std::accumulate carry the
      // running total as `int` and narrow every partial sum of the 64-bit sizes.
      auto n_bytes = std::accumulate(sizes.cbegin(), sizes.cend(), std::int64_t{0});
      // create result
      dh::device_vector<std::int32_t> result(
          static_cast<std::size_t>(n_bytes) / sizeof(std::int32_t), -1);
      auto s_result = common::EraseType(dh::ToSpan(result));

      std::vector<std::int64_t> recv_seg(nccl_comm_->World() + 1, 0);
      rc = nccl_coll_->AllgatherV(*nccl_comm_, s_data, common::Span{sizes.data(), sizes.size()},
                                  common::Span{recv_seg.data(), recv_seg.size()}, s_result, algo);
      SafeColl(rc);
      // check segment size
      // NOTE(review): the segment check is skipped for kBcast, presumably because that
      // algorithm does not populate recv_seg -- confirm against the collective implementation.
      if (algo != AllgatherVAlgo::kBcast) {
        auto size = recv_seg[nccl_comm_->Rank() + 1] - recv_seg[nccl_comm_->Rank()];
        // Compare like-signed values; `size` is a signed 64-bit byte count.
        auto expected =
            static_cast<std::int64_t>(n * nccl_comm_->Rank() * sizeof(std::int32_t));
        ASSERT_EQ(size, expected);
        ASSERT_EQ(size, sizes[nccl_comm_->Rank()]);
      }
      // check data: the segment of rank r starts at element offset k and holds n * r copies of r.
      std::size_t k{0};
      for (std::int32_t r = 0; r < nccl_comm_->World(); ++r) {
        std::size_t s = n * r;
        auto current = dh::ToSpan(result).subspan(k, s);
        std::vector<std::int32_t> h_data(current.size());
        dh::CopyDeviceSpanToVector(&h_data, current);
        for (auto v : h_data) {
          ASSERT_EQ(v, r);
        }
        k += s;
      }
      // The per-rank segments must cover the result buffer exactly.
      ASSERT_EQ(k, result.size());
    }
  }
};

// Fixture for the multi-GPU allgather tests; adds nothing on top of SocketTest.
class MGPUAllgatherTest : public SocketTest {};
}  // namespace

// Spawns one worker per visible GPU and runs the AllgatherV checks with the ring
// algorithm, followed by the broadcast algorithm on the same worker.
TEST_F(MGPUAllgatherTest, MGPUTestVRing) {
  auto n_gpus = curt::AllVisibleGPUs();
  TestDistributed(
      n_gpus, [n_gpus](std::string host, std::int32_t port, std::chrono::seconds timeout,
                       std::int32_t rank) {
        Worker worker{host, port, timeout, n_gpus, rank};
        worker.Setup();
        // Same order as listed: ring first, then broadcast.
        for (auto algo : {AllgatherVAlgo::kRing, AllgatherVAlgo::kBcast}) {
          worker.TestV(algo);
        }
      });
}

// Spawns one worker per visible GPU and runs the AllgatherV checks with the
// broadcast algorithm only.
TEST_F(MGPUAllgatherTest, MGPUTestVBcast) {
  auto n_gpus = curt::AllVisibleGPUs();
  TestDistributed(
      n_gpus, [n_gpus](std::string host, std::int32_t port, std::chrono::seconds timeout,
                       std::int32_t rank) {
        Worker worker{host, port, timeout, n_gpus, rank};
        worker.Setup();
        worker.TestV(AllgatherVAlgo::kBcast);
      });
}
}  // namespace xgboost::collective
#endif  // defined(XGBOOST_USE_NCCL)