File: test_tracker.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (124 lines) | stat: -rw-r--r-- 3,342 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <chrono>   // for seconds
#include <cstdint>  // for int32_t
#include <string>   // for string
#include <thread>   // for thread
#include <vector>   // for vector

#include "../../../src/collective/comm.h"
#include "../helpers.h"  // for GMockThrow
#include "test_worker.h"

namespace xgboost::collective {
namespace {
/**
 * Test worker that reports its own rank back through the communicator's tracker log.
 */
class PrintWorker : public WorkerForTest {
 public:
  using WorkerForTest::WorkerForTest;

  /**
   * Pass the message "ack:<rank>" to `comm_.LogTracker` and hand the returned status to
   * `SafeColl` for checking.
   */
  void Print() {
    auto const rank_str = std::to_string(comm_.Rank());
    auto status = comm_.LogTracker("ack:" + rank_str);
    SafeColl(status);
  }
};
}  // namespace

/**
 * Start a tracker and construct `n_workers` workers against it, one per thread.
 *
 * Verifies the tracker's ready flag before `Run()` and after `WorkerArgs()`, that the
 * "dmlc_tracker_uri" worker argument equals the fixture's `host`, and passes the result of
 * the tracker future to `SafeColl` once every worker thread has been joined.
 */
TEST_F(TrackerTest, Bootstrap) {
  RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
  ASSERT_TRUE(HasTimeout(tracker.Timeout()));
  // The tracker has not been started yet.
  ASSERT_FALSE(tracker.Ready());
  auto tracker_done = tracker.Run();

  // The ready flag is expected to be set once the worker arguments have been obtained.
  auto worker_args = tracker.WorkerArgs();
  ASSERT_TRUE(tracker.Ready());
  ASSERT_EQ(get<String const>(worker_args["dmlc_tracker_uri"]), host);

  std::int32_t const port = tracker.Port();

  std::vector<std::thread> threads;
  for (std::int32_t rank = 0; rank < n_workers; ++rank) {
    // Each thread only constructs a worker and lets it go out of scope again.
    threads.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, rank}; });
  }
  for (auto &worker_thread : threads) {
    worker_thread.join();
  }
  SafeColl(tracker_done.get());

  // Negative and zero durations must not be reported as a timeout.
  ASSERT_FALSE(HasTimeout(std::chrono::seconds{-1}));
  ASSERT_FALSE(HasTimeout(std::chrono::seconds{0}));
}

/**
 * Start a tracker, then have every worker call `PrintWorker::Print()` from its own thread.
 *
 * The status of `WaitUntilReady()` and the result of the tracker future are both passed to
 * `SafeColl`.
 */
TEST_F(TrackerTest, Print) {
  RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
  auto tracker_done = tracker.Run();

  auto ready_status = tracker.WaitUntilReady();
  SafeColl(ready_status);

  std::int32_t const port = tracker.Port();

  std::vector<std::thread> threads;
  for (std::int32_t rank = 0; rank < n_workers; ++rank) {
    threads.emplace_back([=] {
      PrintWorker printer{host, port, timeout, n_workers, rank};
      printer.Print();
    });
  }

  for (auto &worker_thread : threads) {
    worker_thread.join();
  }

  SafeColl(tracker_done.get());
}

/**
 * Check that the fixture's `host` does not look like an IPv4 loopback address, i.e. that
 * it does not begin with "127.".
 *
 * The match is anchored at the start of the string on purpose: a routable address such as
 * 192.168.127.5 or 10.127.0.1 contains "127." in a later octet, and an unanchored
 * substring search would wrongly fail the test for it.
 */
TEST_F(TrackerTest, GetHostAddress) {
  // With a start position of 0, rfind can only report a match at index 0.
  ASSERT_NE(host.rfind("127.", 0), std::string::size_type{0}) << "host: " << host;
}

/**
 * Connect to the tracker once it has already finished. The workers are expected to fail
 * with an error rather than hang.
 */
TEST_F(TrackerTest, AfterShutdown) {
  RabitTracker tracker{MakeTrackerConfig(host, n_workers, timeout)};
  auto tracker_done = tracker.Run();

  auto ready_status = tracker.WaitUntilReady();
  SafeColl(ready_status);

  std::int32_t const port = tracker.Port();

  // First round: no-op workers, which causes the tracker to shut down.
  std::vector<std::thread> threads;
  for (std::int32_t rank = 0; rank < n_workers; ++rank) {
    threads.emplace_back([=] { WorkerForTest worker{host, port, timeout, n_workers, rank}; });
  }

  for (auto &worker_thread : threads) {
    worker_thread.join();
  }

  SafeColl(tracker_done.get());

  // Second round: the tracker is gone, so constructing a worker should now fail.
  threads.clear();
  for (std::int32_t rank = 0; rank < n_workers; ++rank) {
    auto connect = [=] {
      WorkerForTest worker{host, port, timeout, n_workers, rank};
    };
    // The reported error differs per platform: Linux refuses the connection, Apple
    // platforms report an "operation now in progress" poll failure, and Windows runs into
    // a timeout.
#if defined(__linux__)
    auto expect_failure = [=] { ASSERT_THAT(connect, GMockThrow("Connection refused")); };
#else
    auto expect_failure = [=] { ASSERT_THAT(connect, GMockThrow("Failed to connect to")); };
#endif
    threads.emplace_back(expect_failure);
  }
  for (auto &worker_thread : threads) {
    worker_thread.join();
  }
}
}  // namespace xgboost::collective