File: test_prune.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (89 lines) | stat: -rw-r--r-- 2,977 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/**
 * Copyright 2018-2023 by XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/learner.h>
#include <xgboost/tree_updater.h>

#include <memory>
#include <string>
#include <vector>

#include "../../../src/tree/param.h"  // for TrainParam
#include "../helpers.h"

namespace xgboost::tree {
TEST(Updater, Prune) {
  // Drives the "prune" tree updater over a hand-built tree and checks how many
  // extra nodes remain after each pruning pass, first while varying the split
  // loss_chg against min_split_loss and then with max_depth set.
  constexpr int kCols = 16;

  std::vector<std::pair<std::string, std::string>> cfg{
      {"num_feature", std::to_string(kCols)},
      {"min_split_loss", "10"}};
  Context ctx;

  // Placeholder gradients with shape {8, 1}; the assertions below only look at
  // the tree structure.
  linalg::Matrix<GradientPair> gradients{
      {{0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f}, {0.50f, 0.25f},
       {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}, {0.25f, 0.24f}},
      {8, 1},
      ctx.Device()};
  std::shared_ptr<DMatrix> dmat(RandomDataGenerator{32, 10, 0}.GenerateDMatrix());

  // The single tree under test, plus the pointer list handed to the updater.
  RegTree tree{1u, kCols};
  std::vector<RegTree*> trees{&tree};

  // Training parameters, the task descriptor and the pruner itself.
  TrainParam param;
  param.UpdateAllowUnknown(cfg);

  ObjInfo task{ObjInfo::kRegression};
  std::unique_ptr<TreeUpdater> pruner{TreeUpdater::Create("prune", &ctx, &task)};

  std::vector<HostDeviceVector<bst_node_t>> position(trees.size());

  // Every pruning pass in this test uses the same argument set.
  auto run_pruner = [&]() {
    pruner->Update(&param, &gradients, dmat.get(), position, trees);
  };

  // Splits node `nid` with the fixed split settings shared by the depth checks;
  // only the reported loss change differs between calls.
  auto split_child = [&tree](bst_node_t nid, float loss_chg) {
    tree.ExpandNode(nid,
                    0, 0.5f, true, 0.3, 0.4, 0.5,
                    /*loss_chg=*/loss_chg, 0.0f,
                    /*left_sum=*/0.0f, /*right_sum=*/0.0f);
  };

  // Case 1: loss_chg (0) < min_split_loss (10) -- no extra nodes expected
  // after pruning.
  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f,
                  /*loss_chg=*/0.0f, 0.0f,
                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
  run_pruner();

  ASSERT_EQ(tree.NumExtraNodes(), 0);

  // Case 2: loss_chg (11) > min_split_loss (10) -- both children expected to
  // remain.
  tree.ExpandNode(0, 0, 0, true, 0.0f, 0.3f, 0.4f,
                  /*loss_chg=*/11.0f, 0.0f,
                  /*left_sum=*/0.0f, /*right_sum=*/0.0f);
  run_pruner();

  ASSERT_EQ(tree.NumExtraNodes(), 2);

  // Case 3: loss_chg == min_split_loss -- the split is still expected to remain.
  tree.Stat(0).loss_chg = 10;
  run_pruner();

  ASSERT_EQ(tree.NumExtraNodes(), 2);

  // Depth check: split both children of the root with loss_chg above
  // min_split_loss, then set max_depth to 1 and expect only the root split
  // to be left.
  split_child(tree[0].LeftChild(), 18.0f);
  split_child(tree[0].RightChild(), 19.0f);

  cfg.emplace_back("max_depth", "1");
  param.UpdateAllowUnknown(cfg);
  run_pruner();
  ASSERT_EQ(tree.NumExtraNodes(), 2);

  // Split the left child again and append a second "min_split_loss" entry of
  // "0"; two extra nodes are still expected after pruning.
  // NOTE(review): cfg now holds both "10" and "0" for min_split_loss --
  // presumably the later entry takes effect; confirm against the parameter
  // update semantics.
  split_child(tree[0].LeftChild(), 18.0f);
  cfg.emplace_back("min_split_loss", "0");
  param.UpdateAllowUnknown(cfg);

  run_pruner();
  ASSERT_EQ(tree.NumExtraNodes(), 2);
}
}  // namespace xgboost::tree