1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
|
/**
* Copyright 2023, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/c_api.h>
#include <chrono> // for ""s
#include <thread> // for thread
#include "../../../src/collective/allgather.h" // for RingAllgather
#include "../../../src/collective/tracker.h"
#include "test_worker.h" // for SocketTest
#include "xgboost/json.h" // for Json
namespace xgboost::collective {
namespace {
class TrackerAPITest : public SocketTest {};
} // namespace
TEST_F(TrackerAPITest, CAPI) {
TrackerHandle handle;
Json config{Object{}};
std::int32_t n_workers{2};
config["dmlc_communicator"] = String{"rabit"};
config["n_workers"] = n_workers;
config["timeout"] = 1;
auto config_str = Json::Dump(config);
auto rc = XGTrackerCreate(config_str.c_str(), &handle);
ASSERT_EQ(rc, 0);
rc = XGTrackerRun(handle, nullptr);
ASSERT_EQ(rc, 0);
std::thread bg_wait{[&] {
Json config{Object{}};
auto config_str = Json::Dump(config);
auto rc = XGTrackerWaitFor(handle, config_str.c_str());
ASSERT_EQ(rc, 0);
}};
char const* cargs;
rc = XGTrackerWorkerArgs(handle, &cargs);
ASSERT_EQ(rc, 0);
auto args = Json::Load(StringView{cargs});
std::string host;
SafeColl(GetHostAddress(&host));
ASSERT_EQ(host, get<String const>(args["dmlc_tracker_uri"]));
auto port = get<Integer const>(args["dmlc_tracker_port"]);
ASSERT_NE(port, 0);
std::vector<std::thread> workers;
using std::chrono_literals::operator""s;
for (std::int32_t r = 0; r < n_workers; ++r) {
workers.emplace_back([=] {
WorkerForTest w{host, static_cast<std::int32_t>(port), 8s, n_workers, r};
// basic test
std::vector<std::int32_t> data(w.Comm().World(), 0);
data[w.Comm().Rank()] = w.Comm().Rank();
auto rc = RingAllgather(w.Comm(), common::Span{data.data(), data.size()});
SafeColl(rc);
for (std::int32_t r = 0; r < w.Comm().World(); ++r) {
ASSERT_EQ(data[r], r);
}
});
}
for (auto& w : workers) {
w.join();
}
rc = XGTrackerFree(handle);
ASSERT_EQ(rc, 0);
bg_wait.join();
}
} // namespace xgboost::collective
|