File: test_federated_coll.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (93 lines) | stat: -rw-r--r-- 3,463 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
/**
 * Copyright 2022-2023, XGBoost contributors
 */
#include <gtest/gtest.h>
#include <xgboost/span.h>  // for Span

#include <algorithm>    // for min, clamp, transform
#include <array>        // for array
#include <cstddef>      // for size_t
#include <cstdint>      // for int32_t, int64_t
#include <memory>       // for shared_ptr, make_shared
#include <string>       // for string
#include <string_view>  // for string_view
#include <thread>       // for thread
#include <vector>       // for vector

#include "../../../../src/common/type.h"   // for EraseType
#include "../../collective/test_worker.h"  // for SocketTest
#include "federated_coll.h"                // for FederatedColl
#include "federated_comm.h"                // for FederatedComm
#include "test_worker.h"                   // for TestFederated

namespace xgboost::collective {
namespace {
/**
 * @brief Fixture shared by the FederatedColl tests in this file.
 *
 * Adds no members of its own; everything comes from the SocketTest base.
 */
class FederatedCollTest : public SocketTest {
};
}  // namespace

/**
 * @brief Sum-allreduce a small int32 array through FederatedColl.
 *
 * Every worker contributes the same {1, 2, 3, 4, 5}, so after an Op::kSum reduction each
 * element is expected to equal its initial value multiplied by the number of workers.
 */
TEST_F(FederatedCollTest, Allreduce) {
  // hardware_concurrency() is allowed to return 0 when the value is not computable. Clamp it
  // so the worker count handed to TestFederated is always within [1, 3].
  auto const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 3u));
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
    std::array<std::int32_t, 5> buffer = {1, 2, 3, 4, 5};
    // Value-initialised so no element is ever read uninitialised; filled right below.
    std::array<std::int32_t, 5> expected{};
    std::transform(buffer.cbegin(), buffer.cend(), expected.begin(),
                   [=](auto i) { return i * n_workers; });

    auto coll = std::make_shared<FederatedColl>();
    auto rc = coll->Allreduce(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}),
                              ArrayInterfaceHandler::kI4, Op::kSum);
    // Check the returned status before looking at the reduced values.
    SafeColl(rc);
    for (std::size_t i = 0; i < buffer.size(); ++i) {
      ASSERT_EQ(buffer[i], expected[i]);
    }
  });
}

/**
 * @brief Broadcast the string "hello" from rank 0 through FederatedColl.
 *
 * Rank 0 starts with the payload, every other rank starts with a blank buffer of the same
 * length. After the call every rank is expected to hold "hello".
 */
TEST_F(FederatedCollTest, Broadcast) {
  // hardware_concurrency() is allowed to return 0 when the value is not computable. Clamp it
  // so the worker count handed to TestFederated is always within [1, 3].
  auto const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 3u));
  TestFederated(n_workers, [](std::shared_ptr<FederatedComm> comm, std::int32_t) {
    FederatedColl coll{};
    // Only the initial content depends on the rank; the call itself is identical for all.
    std::string buffer = comm->Rank() == 0 ? "hello" : "     ";
    auto rc = coll.Broadcast(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}), 0);
    // Check the status before the payload, as the Allreduce and Allgather tests do, so a
    // failed call is not reported merely as a string mismatch.
    SafeColl(rc);
    ASSERT_EQ(buffer, "hello");
  });
}

/**
 * @brief Allgather one int32 per worker through FederatedColl.
 *
 * The buffer has one slot per worker. Each worker writes its own rank into its own slot and
 * leaves the rest zero; after the call slot i is expected to contain i on every worker.
 */
TEST_F(FederatedCollTest, Allgather) {
  // hardware_concurrency() is allowed to return 0 when the value is not computable. Clamp it
  // so the worker count is always within [1, 3]; with 0 the buffer would be empty and the
  // verification loop below would not execute at all.
  auto const n_workers =
      static_cast<std::int32_t>(std::clamp(std::thread::hardware_concurrency(), 1u, 3u));
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
    FederatedColl coll{};

    std::vector<std::int32_t> buffer(n_workers, 0);
    auto const rank = comm->Rank();
    buffer[rank] = rank;
    auto rc = coll.Allgather(*comm, common::EraseType(common::Span{buffer.data(), buffer.size()}));
    // Check the returned status before looking at the gathered values.
    SafeColl(rc);
    for (std::int32_t i = 0; i < n_workers; ++i) {
      ASSERT_EQ(buffer[i], i);
    }
  });
}

/**
 * @brief Variable-length allgather of two strings through FederatedColl.
 *
 * Runs with exactly two workers. Rank 0 sends "Federated" and rank 1 sends " Learning!!!";
 * the receive buffer is sized to hold both and is expected to end up as their concatenation
 * in rank order.
 */
TEST_F(FederatedCollTest, AllgatherV) {
  std::int32_t n_workers = 2;
  TestFederated(n_workers, [=](std::shared_ptr<FederatedComm> comm, std::int32_t) {
    FederatedColl coll{};

    // inputs[r] is the payload contributed by rank r.
    std::vector<std::string_view> inputs{"Federated", " Learning!!!"};
    std::vector<std::int64_t> sizes{static_cast<std::int64_t>(inputs[0].size()),
                                    static_cast<std::int64_t>(inputs[1].size())};
    // One more entry than there are inputs, zero-initialised; handed to AllgatherV as the
    // receive-segments argument.
    std::vector<std::int64_t> recv_segments(inputs.size() + 1, 0);
    std::string r;
    r.resize(sizes[0] + sizes[1]);

    auto const& send = inputs[comm->Rank()];
    auto rc = coll.AllgatherV(
        *comm, common::EraseType(common::Span{send.data(), send.size()}),
        common::Span{sizes.data(), sizes.size()}, recv_segments,
        common::EraseType(common::Span{r.data(), r.size()}), AllgatherVAlgo::kRing);

    // Check the status before the payload, as the Allreduce and Allgather tests do.
    SafeColl(rc);
    EXPECT_EQ(r, "Federated Learning!!!");
  });
}
}  // namespace xgboost::collective