File: test_federated_learner.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (122 lines) | stat: -rw-r--r-- 4,082 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
/**
 * Copyright 2023-2024, XGBoost contributors
 *
 * Some other tests for federated learning are in the main test suite (test_learner.cc).
 */
#include <dmlc/parameter.h>
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/objective.h>

#include "../../../../src/collective/communicator-inl.h"
#include "../../../../src/common/linalg_op.h"  // for begin, end
#include "../../helpers.h"
#include "../../objective_helpers.h"  // for MakeObjNamesForTest, ObjTestNameGenerator
#include "test_worker.h"

namespace xgboost {
namespace {
auto MakeModel(std::string tree_method, std::string device, std::string objective,
               std::shared_ptr<DMatrix> dmat) {
  std::unique_ptr<Learner> learner{Learner::Create({dmat})};
  learner->SetParam("tree_method", tree_method);
  learner->SetParam("device", device);
  learner->SetParam("objective", objective);
  if (objective.find("quantile") != std::string::npos) {
    learner->SetParam("quantile_alpha", "0.5");
  }
  if (objective.find("multi") != std::string::npos) {
    learner->SetParam("num_class", "3");
  }
  learner->UpdateOneIter(0, dmat);
  Json config{Object{}};
  learner->SaveConfig(&config);

  Json model{Object{}};
  learner->SaveModel(&model);
  return model;
}

/**
 * @brief Per-worker check: train on this rank's column slice and compare with a reference.
 *
 * Every rank generates a `rows` x `cols` random matrix. `GenerateDMatrix` is passed
 * `rank == 0`, and `MakeLabelForObjTest` runs only on rank 0. Presumably only rank 0 holds
 * labels, as in the vertical-federated setup (TODO confirm against RandomDataGenerator).
 * The matrix is then cut with `SliceCol(world_size, rank)` and a one-round model is trained
 * on the slice.
 *
 * @param rows                Number of rows to generate.
 * @param cols                Number of columns to generate before slicing.
 * @param expected_base_score Base score the federated model must reproduce exactly.
 * @param expected_model      Reference model the federated model must equal.
 * @param tree_method         Forwarded to `MakeModel`.
 * @param device              Forwarded to `MakeModel`.
 * @param objective           Forwarded to `MakeModel` and `MakeLabelForObjTest`.
 */
void VerifyObjective(std::size_t rows, std::size_t cols, float expected_base_score,
                     Json const &expected_model, std::string const &tree_method,
                     std::string const &device, std::string const &objective) {
  auto rank = collective::GetRank();
  std::shared_ptr<DMatrix> dmat{RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(rank == 0)};

  if (rank == 0) {
    MakeLabelForObjTest(dmat, objective);
  }
  std::shared_ptr<DMatrix> sliced{dmat->SliceCol(collective::GetWorldSize(), rank)};

  auto model = MakeModel(tree_method, device, objective, sliced);
  auto base_score = GetBaseScore(model);
  // Exact comparison on purpose: the federated run is expected to reproduce the reference
  // computation bit-for-bit.
  ASSERT_EQ(base_score, expected_base_score) << " rank " << rank;
  ASSERT_EQ(model, expected_model) << " rank " << rank;
}
}  // namespace

/**
 * Value-parameterized fixture: the parameter is the objective name under test.
 */
class VerticalFederatedLearnerTest : public ::testing::TestWithParam<std::string> {
  static int constexpr kWorldSize{3};

 protected:
  /**
   * @brief Build a reference model locally, then verify it under federated training.
   *
   * A labeled 16x16 random matrix is prepared in this process and a one-round reference
   * model is trained on it. `VerifyObjective` is then run through
   * `collective::TestFederatedGlobal` with `kWorldSize` workers, checking the reference
   * base score and model.
   */
  void Run(std::string tree_method, std::string device, std::string objective) {
    static std::size_t constexpr kRows{16};
    static std::size_t constexpr kCols{16};

    std::shared_ptr<DMatrix> p_fmat{RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(true)};
    MakeLabelForObjTest(p_fmat, objective);

    auto &info = p_fmat->Info();
    // Every row gets the label interval [1, 10] (presumably for survival objectives;
    // TODO confirm).
    info.labels_lower_bound_.HostVector().assign(kRows, 1.0f);
    info.labels_upper_bound_.HostVector().assign(kRows, 10.0f);

    if (objective.find("rank:") != std::string::npos) {
      // Ranking objectives: alternate relevance labels 1, 0, 1, 0, ...
      auto labels = info.labels.HostView();
      std::size_t idx = 0;
      for (auto &label : labels) {
        label = (idx++ % 2 == 0);
      }
    }

    auto reference_model = MakeModel(tree_method, device, objective, p_fmat);
    auto reference_score = GetBaseScore(reference_model);
    collective::TestFederatedGlobal(kWorldSize, [&]() {
      VerifyObjective(kRows, kCols, reference_score, reference_model, tree_method, device,
                      objective);
    });
  }
};

// CPU `approx` tree method for the objective supplied as the test parameter.
TEST_P(VerticalFederatedLearnerTest, Approx) { this->Run("approx", "cpu", GetParam()); }

// CPU `hist` tree method for the objective supplied as the test parameter.
TEST_P(VerticalFederatedLearnerTest, Hist) { this->Run("hist", "cpu", GetParam()); }

#if defined(XGBOOST_USE_CUDA)
// GPU variants: same tree methods as above, on device "cuda:0".
TEST_P(VerticalFederatedLearnerTest, GPUApprox) { this->Run("approx", "cuda:0", GetParam()); }

TEST_P(VerticalFederatedLearnerTest, GPUHist) { this->Run("hist", "cuda:0", GetParam()); }
#endif  // defined(XGBOOST_USE_CUDA)

// Instantiate every TEST_P above once per objective from MakeObjNamesForTest(); test
// names are derived from the objective by ObjTestNameGenerator.
INSTANTIATE_TEST_SUITE_P(
    FederatedLearnerObjective, VerticalFederatedLearnerTest,
    ::testing::ValuesIn(MakeObjNamesForTest()),
    [](::testing::TestParamInfo<VerticalFederatedLearnerTest::ParamType> const &param_info) {
      return ObjTestNameGenerator(param_info);
    });
}  // namespace xgboost