File: test_allreduce.cu

/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#if defined(XGBOOST_USE_NCCL)
#include <gtest/gtest.h>
#include <thrust/host_vector.h>  // for host_vector

#include "../../../src/collective/comm.cuh"        // for NCCLComm
#include "../../../src/common/cuda_rt_utils.h"     // for AllVisibleGPUs
#include "../../../src/common/device_helpers.cuh"  // for ToSpan,  device_vector
#include "../../../src/common/type.h"              // for EraseType
#include "test_worker.cuh"                         // for NCCLWorkerForTest
#include "test_worker.h"                           // for WorkerForTest, TestDistributed

namespace xgboost::collective {
namespace {
class MGPUAllreduceTest : public SocketTest {};

class Worker : public NCCLWorkerForTest {
 public:
  using NCCLWorkerForTest::NCCLWorkerForTest;

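  // Returns true when the NCCL runtime is older than 2.23, in which case the
  // callers skip the timeout tests below.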
  bool SkipIfOld() {
    auto nccl = dynamic_cast<NCCLComm const*>(nccl_comm_.get());
    std::int32_t major = 0, minor = 0, patch = 0;
    SafeColl(nccl->Stub()->GetVersion(&major, &minor, &patch));
    CHECK_GE(major, 2);
    bool too_old = minor < 23;
    if (too_old) {
      LOG(INFO) << "NCCL compile version:" << NCCL_VERSION_CODE << " runtime version:" << major
                << "." << minor << "." << patch;
    }
    return too_old;
  }

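  // Each rank sets only its own element to all ones; after a bitwise-OR
  // allreduce every element of the buffer should be all ones.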
  void BitOr() {
    dh::device_vector<std::uint32_t> data(comm_.World(), 0);
    data[comm_.Rank()] = ~std::uint32_t{0};
    auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
                                    ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
    SafeColl(rc);
    thrust::host_vector<std::uint32_t> h_data(data.size());
    thrust::copy(data.cbegin(), data.cend(), h_data.begin());
    for (auto v : h_data) {
      ASSERT_EQ(v, ~std::uint32_t{0});
    }
  }

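  // Sum-allreduce a buffer filled with 1.5; afterwards every element should
  // equal 1.5 times the world size.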
  void Acc() {
    dh::device_vector<double> data(314, 1.5);
    auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
                                    ArrayInterfaceHandler::kF8, Op::kSum);
    SafeColl(rc);
    for (std::size_t i = 0; i < data.size(); ++i) {
      auto v = data[i];
      ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
    }
  }

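  // Same reduction as Acc(), but returns the result instead of asserting so
  // the timeout test can inspect the error report.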
  Result NoCheck() {
    dh::device_vector<double> data(314, 1.5);
    auto rc = nccl_coll_->Allreduce(*nccl_comm_, common::EraseType(dh::ToSpan(data)),
                                    ArrayInterfaceHandler::kF8, Op::kSum);
    return rc;
  }

  ~Worker() noexcept(false) override = default;
};
}  // namespace

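// Runs one worker per visible GPU and checks the bitwise-OR reduction.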
TEST_F(MGPUAllreduceTest, BitOr) {
  auto n_workers = curt::AllVisibleGPUs();
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    Worker w{host, port, timeout, n_workers, r};
    w.Setup();
    w.BitOr();
  });
}

TEST_F(MGPUAllreduceTest, Sum) {
  auto n_workers = curt::AllVisibleGPUs();
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    Worker w{host, port, timeout, n_workers, r};
    w.Setup();
    w.Acc();
  });
}

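// Workers below are constructed with a 1s collective timeout; each scenario
// forces a rank to delay or skip the allreduce so the waiting side surfaces an
// "NCCL timeout:" error report.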
TEST_F(MGPUAllreduceTest, Timeout) {
  auto n_workers = curt::AllVisibleGPUs();
  if (n_workers <= 1) {
    GTEST_SKIP_("Requires more than one GPU to run.");
  }
  using std::chrono_literals::operator""s;

  TestDistributed(
      n_workers,
      [=](std::string host, std::int32_t port, std::chrono::seconds, std::int32_t r) {
        auto w = std::make_unique<Worker>(host, port, 1s, n_workers, r);
        w->Setup();
        if (w->SkipIfOld()) {
          GTEST_SKIP_("nccl is too old.");
          return;
        }
        // 1s for worker timeout, sleeping for 2s should trigger a timeout error.
        if (r == 0) {
          std::this_thread::sleep_for(2s);
        }
        auto rc = w->NoCheck();
        if (r == 1) {
          auto rep = rc.Report();
          ASSERT_NE(rep.find("NCCL timeout:"), std::string::npos) << rep;
        }

        w.reset();
      },
      // We use 8s for the tracker to make sure shutdown is successful.
      8s);

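  // Second scenario: only rank 0 enters the allreduce, so it should observe
  // the timeout itself while waiting for peers that never join.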
  TestDistributed(
      n_workers,
      [=](std::string host, std::int32_t port, std::chrono::seconds, std::int32_t r) {
        auto w = std::make_unique<Worker>(host, port, 1s, n_workers, r);
        w->Setup();
        if (w->SkipIfOld()) {
          GTEST_SKIP_("nccl is too old.");
          return;
        }
        // Only one of the workers is doing allreduce.
        if (r == 0) {
          auto rc = w->NoCheck();
          ASSERT_NE(rc.Report().find("NCCL timeout:"), std::string::npos) << rc.Report();
        }

        w.reset();
      },
      // We use 8s for the tracker to make sure shutdown is successful.
      8s);
}
}  // namespace xgboost::collective
#endif  // defined(XGBOOST_USE_NCCL)