File: test_worker.h

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (86 lines) | stat: -rw-r--r-- 2,633 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
/**
 * Copyright 2022-2023, XGBoost contributors
 */
#pragma once

#include <gtest/gtest.h>

#include <chrono>  // for ms, seconds
#include <memory>  // for shared_ptr
#include <thread>  // for thread

#include "../../../../plugin/federated/federated_tracker.h"
#include "../../../../src/collective/comm_group.h"
#include "../../../../src/collective/communicator-inl.h"
#include "federated_comm.h"  // for FederatedComm
#include "xgboost/json.h"    // for Json

namespace xgboost::collective {
/**
 * @brief Assemble the JSON configuration for one federated test worker.
 *
 * @param n_workers Total number of workers, stored under "federated_world_size".
 * @param port      Port number appended to "0.0.0.0:" to form "federated_server_address".
 * @param i         Index of the worker, stored under "federated_rank" and, converted to a
 *                  string, under "dmlc_task_id".
 *
 * @return A JSON object holding the values above, with "dmlc_communicator" set to
 *         "federated" and "dmlc_retry" set to 2.
 */
inline Json FederatedTestConfig(std::int32_t n_workers, std::int32_t port, std::int32_t i) {
  auto const task_id = std::to_string(i);
  auto const server_address = "0.0.0.0:" + std::to_string(port);

  Json worker_config{Object{}};
  // Generic communicator settings.
  worker_config["dmlc_communicator"] = std::string{"federated"};
  worker_config["dmlc_task_id"] = task_id;
  worker_config["dmlc_retry"] = 2;
  // Settings specific to the federated communicator.
  worker_config["federated_world_size"] = n_workers;
  worker_config["federated_rank"] = i;
  worker_config["federated_server_address"] = server_address;
  return worker_config;
}

/**
 * @brief Start a federated tracker and run `fn` on `n_workers` threads against it.
 *
 * The tracker is configured with "federated_secure" set to false and "n_workers" set to
 * the requested worker count. Once the tracker reports that it is ready, one thread per
 * worker is spawned and each calls `fn(port, rank)` with the port reported by the tracker
 * and a rank in `[0, n_workers)`. All threads are joined before the tracker is shut down.
 * Every result returned by the tracker, including the one delivered through the future
 * obtained from `Run()`, is checked with `SafeColl`.
 *
 * @tparam WorkerFn Callable that can be invoked as `fn(std::int32_t port, std::int32_t rank)`.
 *
 * @param n_workers Number of worker threads to spawn.
 * @param fn        Body executed by every worker. Each thread receives its own copy.
 */
template <typename WorkerFn>
void TestFederatedImpl(std::int32_t n_workers, WorkerFn&& fn) {
  Json config{Object()};
  config["federated_secure"] = Boolean{false};
  config["n_workers"] = Integer{n_workers};
  FederatedTracker tracker{config};
  auto fut = tracker.Run();

  // The port is only queried after the tracker has reported that it is ready.
  auto rc = tracker.WaitUntilReady();
  SafeColl(rc);
  std::int32_t port = tracker.Port();

  std::vector<std::thread> workers;
  // The final size is known up front; skip the request for non-positive counts so that a
  // negative value is never converted into a huge unsigned capacity.
  if (n_workers > 0) {
    workers.reserve(static_cast<std::vector<std::thread>::size_type>(n_workers));
  }
  for (std::int32_t i = 0; i < n_workers; ++i) {
    // Capture by value: every thread gets its own rank, port and copy of the callable.
    workers.emplace_back([=] { fn(port, i); });
  }

  // Wait for every worker before tearing the tracker down.
  for (auto& t : workers) {
    t.join();
  }

  rc = tracker.Shutdown();
  SafeColl(rc);
  // Check the result produced by the tracker's `Run()` as well.
  SafeColl(fut.get());
}

/**
 * @brief Run `fn` on `n_workers` threads, handing each one its own `FederatedComm`.
 *
 * Each worker builds its configuration with `FederatedTestConfig`, constructs a
 * `FederatedComm` from the default retry count, the default timeout, its rank as the task
 * id and that configuration, and then calls `fn(comm, rank)`.
 *
 * @tparam WorkerFn Callable that can be invoked as
 *                  `fn(std::shared_ptr<FederatedComm>, std::int32_t rank)`.
 *
 * @param n_workers Number of workers to run.
 * @param fn        Test body executed by every worker.
 */
template <typename WorkerFn>
void TestFederated(std::int32_t n_workers, WorkerFn&& fn) {
  auto run_worker = [&](std::int32_t port, std::int32_t rank) {
    auto worker_config = FederatedTestConfig(n_workers, port, rank);
    std::chrono::seconds timeout{DefaultTimeoutSec()};
    auto comm = std::make_shared<FederatedComm>(DefaultRetry(), timeout, std::to_string(rank),
                                                worker_config);

    fn(comm, rank);
  };
  TestFederatedImpl(n_workers, run_worker);
}

/**
 * @brief Run `fn` on `n_workers` threads, handing each one its own `CommGroup`.
 *
 * Each worker builds its configuration with `FederatedTestConfig`, passes it to
 * `CommGroup::Create`, stores the result in a `std::shared_ptr<CommGroup>` and then calls
 * `fn(comm_group, rank)`.
 *
 * @tparam WorkerFn Callable that can be invoked as
 *                  `fn(std::shared_ptr<CommGroup>, std::int32_t rank)`.
 *
 * @param n_workers Number of workers to run.
 * @param fn        Test body executed by every worker.
 */
template <typename WorkerFn>
void TestFederatedGroup(std::int32_t n_workers, WorkerFn&& fn) {
  auto run_worker = [&](std::int32_t port, std::int32_t rank) {
    auto worker_config = FederatedTestConfig(n_workers, port, rank);
    std::shared_ptr<CommGroup> group{CommGroup::Create(worker_config)};
    fn(group, rank);
  };
  TestFederatedImpl(n_workers, run_worker);
}

/**
 * @brief Run `fn` on `n_workers` threads using the `collective::Init`/`Finalize` interface.
 *
 * Each worker builds its configuration with `FederatedTestConfig`, calls
 * `collective::Init` with it, invokes `fn()` without arguments and finally calls
 * `collective::Finalize`.
 *
 * @tparam WorkerFn Callable that can be invoked as `fn()`.
 *
 * @param n_workers Number of workers to run.
 * @param fn        Test body executed by every worker between `Init` and `Finalize`.
 */
template <typename WorkerFn>
void TestFederatedGlobal(std::int32_t n_workers, WorkerFn&& fn) {
  auto run_worker = [&](std::int32_t port, std::int32_t rank) {
    auto worker_config = FederatedTestConfig(n_workers, port, rank);
    collective::Init(worker_config);
    fn();
    collective::Finalize();
  };
  TestFederatedImpl(n_workers, run_worker);
}
}  // namespace xgboost::collective