File: test_comm.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (52 lines) | stat: -rw-r--r-- 1,484 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#include <gtest/gtest.h>

#include <cstddef>  // for size_t
#include <cstdint>  // for int32_t
#include <thread>   // for thread
#include <vector>   // for vector

#include "../../../src/collective/comm.h"
#include "../../../src/common/type.h"  // for EraseType
#include "test_worker.h"               // for TrackerTest

namespace xgboost::collective {
namespace {
class CommTest : public TrackerTest {};
}  // namespace

TEST_F(CommTest, Channel) {
  auto n_workers = 4;
  RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
  auto fut = tracker.Run();

  std::vector<std::thread> workers;
  std::int32_t port = tracker.Port();

  for (std::int32_t i = 0; i < n_workers; ++i) {
    workers.emplace_back([=] {
      WorkerForTest worker{host, port, timeout, n_workers, i};
      if (i % 2 == 0) {
        auto p_chan = worker.Comm().Chan(i + 1);
        auto rc = Success() << [&] {
          return p_chan->SendAll(
              EraseType(common::Span<std::int32_t const>{&i, static_cast<std::size_t>(1)}));
        } << [&] { return p_chan->Block(); };
        SafeColl(rc);
      } else {
        auto p_chan = worker.Comm().Chan(i - 1);
        std::int32_t r{-1};
        auto rc = Success() << [&] {
          return p_chan->RecvAll(
              EraseType(common::Span<std::int32_t>{&r, static_cast<std::size_t>(1)}));
        } << [&] { return p_chan->Block(); };
        SafeColl(rc);
        ASSERT_EQ(r, i - 1);
      }
    });
  }

  for (auto &w : workers) {
    w.join();
  }

  SafeColl(fut.get());
}
}  // namespace xgboost::collective