File: test_broadcast.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (50 lines) | stat: -rw-r--r-- 1,521 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <xgboost/collective/socket.h>

#include <algorithm>  // for min, find_if
#include <chrono>     // for seconds
#include <cstddef>    // for size_t
#include <cstdint>    // for int32_t
#include <string>     // for string
#include <thread>     // for thread
#include <vector>     // for vector

#include "../../../src/collective/broadcast.h"  // for Broadcast
#include "test_worker.h"                        // for WorkerForTest, TestDistributed

namespace xgboost::collective {
namespace {
class Worker : public WorkerForTest {
 public:
  using WorkerForTest::WorkerForTest;

  void Run() {
    for (std::int32_t r = 0; r < comm_.World(); ++r) {
      // basic test
      std::vector<std::int32_t> data(1, comm_.Rank());
      auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
      SafeColl(rc);
      ASSERT_EQ(data[0], r);
    }

    for (std::int32_t r = 0; r < comm_.World(); ++r) {
      std::vector<std::int32_t> data(1 << 16, comm_.Rank());
      auto rc = Broadcast(this->comm_, common::Span{data.data(), data.size()}, r);
      SafeColl(rc);
      ASSERT_EQ(data[0], r);
    }
  }
};

// Test fixture for Broadcast; all setup is inherited from SocketTest.
struct BroadcastTest : public SocketTest {};
}  // namespace

/**
 * Run the broadcast worker on up to two local workers, one per rank.
 */
TEST_F(BroadcastTest, Basic) {
  // hardware_concurrency() returns 0 when the value cannot be determined; fall back to two
  // workers in that case instead of launching the distributed test with zero workers.
  auto const n_cpus = std::thread::hardware_concurrency();
  std::int32_t n_workers = n_cpus == 0 ? 2 : static_cast<std::int32_t>(std::min(2u, n_cpus));
  TestDistributed(n_workers, [=](std::string host, std::int32_t port, std::chrono::seconds timeout,
                                 std::int32_t r) {
    Worker worker{host, port, timeout, n_workers, r};
    worker.Run();
  });
}
}  // namespace xgboost::collective