File: test_coll_c_api.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (77 lines) | stat: -rw-r--r-- 2,118 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
/**
 * Copyright 2023, XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <xgboost/c_api.h>

#include <chrono>  // for ""s
#include <thread>  // for thread

#include "../../../src/collective/allgather.h"  // for RingAllgather
#include "../../../src/collective/tracker.h"
#include "test_worker.h"   // for SocketTest
#include "xgboost/json.h"  // for Json

namespace xgboost::collective {
namespace {
class TrackerAPITest : public SocketTest {};
}  // namespace

TEST_F(TrackerAPITest, CAPI) {
  TrackerHandle handle;
  Json config{Object{}};
  std::int32_t n_workers{2};
  config["dmlc_communicator"] = String{"rabit"};
  config["n_workers"] = n_workers;
  config["timeout"] = 1;
  auto config_str = Json::Dump(config);
  auto rc = XGTrackerCreate(config_str.c_str(), &handle);
  ASSERT_EQ(rc, 0);
  rc = XGTrackerRun(handle, nullptr);
  ASSERT_EQ(rc, 0);

  std::thread bg_wait{[&] {
    Json config{Object{}};
    auto config_str = Json::Dump(config);
    auto rc = XGTrackerWaitFor(handle, config_str.c_str());
    ASSERT_EQ(rc, 0);
  }};

  char const* cargs;
  rc = XGTrackerWorkerArgs(handle, &cargs);
  ASSERT_EQ(rc, 0);
  auto args = Json::Load(StringView{cargs});

  std::string host;
  SafeColl(GetHostAddress(&host));
  ASSERT_EQ(host, get<String const>(args["dmlc_tracker_uri"]));
  auto port = get<Integer const>(args["dmlc_tracker_port"]);
  ASSERT_NE(port, 0);

  std::vector<std::thread> workers;
  using std::chrono_literals::operator""s;
  for (std::int32_t r = 0; r < n_workers; ++r) {
    workers.emplace_back([=] {
      WorkerForTest w{host, static_cast<std::int32_t>(port), 8s, n_workers, r};
      // basic test
      std::vector<std::int32_t> data(w.Comm().World(), 0);
      data[w.Comm().Rank()] = w.Comm().Rank();

      auto rc = RingAllgather(w.Comm(), common::Span{data.data(), data.size()});
      SafeColl(rc);

      for (std::int32_t r = 0; r < w.Comm().World(); ++r) {
        ASSERT_EQ(data[r], r);
      }
    });
  }
  for (auto& w : workers) {
    w.join();
  }

  rc = XGTrackerFree(handle);
  ASSERT_EQ(rc, 0);

  bg_wait.join();
}
}  // namespace xgboost::collective