1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
|
/**
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>

#include <algorithm>  // for clamp
#include <chrono>     // for seconds
#include <cstddef>    // for size_t
#include <cstdint>    // for int32_t, uint32_t, uint64_t
#include <memory>     // for make_shared, shared_ptr
#include <numeric>    // for accumulate
#include <string>     // for string
#include <thread>     // for thread
#include <vector>     // for vector

#include "../../../src/collective/allreduce.h"
#include "../../../src/collective/coll.h"  // for Coll
#include "../../../src/common/type.h"      // for EraseType
#include "test_worker.h"                   // for WorkerForTest, TestDistributed
namespace xgboost::collective {
namespace {
class AllreduceWorker : public WorkerForTest {
public:
using WorkerForTest::WorkerForTest;
void Basic() {
{
std::vector<double> data(13, 0.0);
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
SafeColl(rc);
ASSERT_EQ(std::accumulate(data.cbegin(), data.cend(), 0.0), 0.0);
}
{
std::vector<double> data(1, 1.0);
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
SafeColl(rc);
ASSERT_EQ(data[0], static_cast<double>(comm_.World()));
}
}
void Restricted() {
this->LimitSockBuf(4096);
std::size_t n = 4096 * 4;
std::vector<std::int32_t> data(comm_.World() * n, 1);
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
SafeColl(rc);
for (auto v : data) {
ASSERT_EQ(v, comm_.World());
}
}
void Acc() {
std::vector<double> data(314, 1.5);
auto rc = Allreduce(comm_, common::Span{data.data(), data.size()}, [](auto lhs, auto rhs) {
for (std::size_t i = 0; i < rhs.size(); ++i) {
rhs[i] += lhs[i];
}
});
SafeColl(rc);
for (std::size_t i = 0; i < data.size(); ++i) {
auto v = data[i];
ASSERT_EQ(v, 1.5 * static_cast<double>(comm_.World())) << i;
}
}
void BitOr() {
std::vector<std::uint32_t> data(comm_.World(), 0);
data[comm_.Rank()] = ~std::uint32_t{0};
auto pcoll = std::shared_ptr<Coll>{new Coll{}};
auto rc = pcoll->Allreduce(comm_, common::EraseType(common::Span{data.data(), data.size()}),
ArrayInterfaceHandler::kU4, Op::kBitwiseOR);
SafeColl(rc);
for (auto v : data) {
ASSERT_EQ(v, ~std::uint32_t{0});
}
}
};
/** @brief gtest fixture for the socket-based allreduce tests; adds nothing on top of SocketTest. */
struct AllreduceTest : SocketTest {};
} // namespace
TEST_F(AllreduceTest, Basic) {
  // hardware_concurrency() may report 0 when the value is unknown; clamp to [1, 7] so at least
  // one worker runs, and convert explicitly rather than narrowing unsigned -> int32 implicitly.
  std::int32_t n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 7u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.Basic();
  });
}
TEST_F(AllreduceTest, Sum) {
  // hardware_concurrency() may report 0 when the value is unknown; clamp to [1, 7] so at least
  // one worker runs, and convert explicitly rather than narrowing unsigned -> int32 implicitly.
  std::int32_t n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 7u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.Acc();
  });
}
TEST_F(AllreduceTest, BitOr) {
  // hardware_concurrency() may report 0 when the value is unknown; clamp to [1, 7] so at least
  // one worker runs, and convert explicitly rather than narrowing unsigned -> int32 implicitly.
  std::int32_t n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 7u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.BitOr();
  });
}
TEST_F(AllreduceTest, Restricted) {
  // hardware_concurrency() may report 0 when the value is unknown; clamp to [1, 3] so at least
  // one worker runs, and convert explicitly rather than narrowing unsigned -> int32 implicitly.
  std::int32_t n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 3u));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    AllreduceWorker worker{host, port, timeout, n_workers, r};
    worker.Restricted();
  });
}
TEST(AllreduceGlobal, Basic) {
  auto n_workers = 3;
  TestDistributedGlobal(n_workers, [&]() {
    // Each worker owns a disjoint pair of slots and sets only that pair to one. A sum allreduce
    // over all workers must therefore leave every slot equal to exactly one.
    std::vector<float> values(n_workers * 2, 0);
    auto s_values = common::Span{values.data(), values.size()};
    auto rank = GetRank();
    for (auto& slot : s_values.subspan(rank * 2, 2)) {
      slot = 1.0f;
    }

    Context ctx;
    auto rc =
        Allreduce(&ctx, linalg::MakeVec(s_values.data(), s_values.size()), collective::Op::kSum);
    SafeColl(rc);

    for (auto reduced : values) {
      ASSERT_EQ(reduced, 1);
    }
  });
}
TEST(AllreduceGlobal, Small) {
  // Test when the data is not large enough to be divided by the number of workers: a single
  // element is reduced across eight workers.
  auto n_workers = 8;
  TestDistributedGlobal(n_workers, [&]() {
    std::uint64_t value{1};
    Context ctx;
    auto rc = Allreduce(&ctx, linalg::MakeVec(&value, 1), collective::Op::kSum);
    SafeColl(rc);
    // Every worker contributes 1, so the sum equals the worker count. Cast to the value's
    // unsigned type to avoid a signed/unsigned comparison inside ASSERT_EQ.
    ASSERT_EQ(value, static_cast<std::uint64_t>(n_workers));
  });
}
} // namespace xgboost::collective
|