File: test_allreduce.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (155 lines) | stat: -rw-r--r-- 4,891 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#include <gtest/gtest.h>

#include <algorithm>  // for clamp, min
#include <cstdint>    // for int32_t, uint32_t, uint64_t
#include <memory>     // for make_shared, shared_ptr
#include <numeric>    // for accumulate
#include <thread>     // for thread
#include <vector>     // for vector

#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/coll.h"  // for Coll
#include "../../../src/common/type.h"  // for EraseType
#include "test_worker.h"               // for WorkerForTest, TestDistributed

namespace xgboost::collective {
namespace {
class AllreduceWorker : public WorkerForTest {
 public:
  using WorkerForTest::WorkerForTest;

  void Basic() {
    {
      std::vector<double> data(13, 0.0);
      auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
        for (std::size_t i = 0; i < rhs.size(); ++i) {
          rhs[i] += lhs[i];
        }
      });
      SafeColl(rc);
      ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
    }
    {
      std::vector<double> data(1, 1.0);
      auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
        for (std::size_t i = 0; i < rhs.size(); ++i) {
          rhs[i] += lhs[i];
        }
      });
      SafeColl(rc);
      ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
    }
  }

  void Restricted() {
    this->LimitSockBuf(4096);

    std::size_t n = 4096 * 4;
    std::vector<std::int32_t> data(comm_.World() * n, 1);
    auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
      for (std::size_t i = 0; i < rhs.size(); ++i) {
        rhs[i] += lhs[i];
      }
    });
    SafeColl(rc);
    for (auto v : data) {
      ASSERT_EQ(v, comm_.World());
    }
  }

  void Acc() {
    std::vector<double> data(314, 1.5);
    auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
      for (std::size_t i = 0; i < rhs.size(); ++i) {
        rhs[i] += lhs[i];
      }
    });
    SafeColl(rc);
    for (std::size_t i = 0; i < data.size(); ++i) {
      auto v = data[i];
      ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
    }
  }

  void BitOr() {
    std::vector<std::uint32_t> data(comm_.World(), 0);
    data[comm_.Rank()] = ~std::uint32_t{0};
    auto pcoll = std::shared_ptr<Coll>{new Coll{}};
    auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
                               ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
    SafeColl(rc);
    for (auto v : data) {
      ASSERT_EQ(v, ~std::uint32_t{0});
    }
  }
};

// gtest fixture for the socket-based allreduce tests; adds nothing beyond SocketTest.
struct AllreduceTest : public SocketTest {};
}  // namespace

// Runs AllreduceWorker::Basic on up to 7 workers.
TEST_F(AllreduceTest, Basic) {
  // hardware_concurrency() returns 0 when the core count is not computable; clamp so the
  // test always launches at least one worker and at most seven.
  std::int32_t const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 7u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.Basic();
  });
}

// Runs AllreduceWorker::Acc (sum of 1.5-valued doubles) on up to 7 workers.
TEST_F(AllreduceTest, Sum) {
  // hardware_concurrency() may report 0 (unknown); never launch zero workers.
  std::int32_t const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 7u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.Acc();
  });
}

// Runs AllreduceWorker::BitOr (bitwise-OR through Coll) on up to 7 workers.
TEST_F(AllreduceTest, BitOr) {
  // hardware_concurrency() may report 0 (unknown); never launch zero workers.
  std::int32_t const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 7u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.BitOr();
  });
}

// Runs AllreduceWorker::Restricted (capped socket buffer) on up to 3 workers.
TEST_F(AllreduceTest, Restricted) {
  // hardware_concurrency() may report 0 (unknown); never launch zero workers.
  std::int32_t const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 3u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.Restricted();
  });
}

// Global-communicator sum-allreduce: each of 3 ranks marks its own two-element slice with
// ones, so after the reduction every slot holds exactly one contribution.
TEST(AllreduceGlobal, Basic) {
  auto n_workers = 3;
  TestDistributedGlobal(n_workers, [&]() {
    std::vector<float> buffer(n_workers * 2, 0);
    auto whole = common::Span{buffer.data(), buffer.size()};
    auto mine = whole.subspan(GetRank() * 2, 2);
    for (std::size_t i = 0; i < mine.size(); ++i) {
      mine[i] = 1.0f;
    }

    Context ctx;
    auto rc = Allreduce(&ctx, linalg::MakeVec(whole.data(), whole.size()), collective::Op::kSum);
    SafeColl(rc);

    for (std::size_t i = 0; i < whole.size(); ++i) {
      ASSERT_EQ(whole[i], 1);
    }
  });
}

TEST(AllreduceGlobal, Small) {
  // Test when the data is not large enough to be divided among the workers: a single
  // element is sum-reduced across 8 workers.
  auto n_workers = 8;
  TestDistributedGlobal(n_workers, [&]() {
    std::uint64_t value{1};
    Context ctx;
    auto rc = Allreduce(&ctx, linalg::MakeVec(&value, 1), collective::Op::kSum);
    SafeColl(rc);
    // Compare in the element type; ASSERT_EQ(uint64_t, int) is a signed/unsigned comparison.
    ASSERT_EQ(value, static_cast<std::uint64_t>(n_workers));
  });
}
}  // namespace xgboost::collective