File: test_partitioner.h

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (45 lines) | stat: -rw-r--r-- 2,021 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/**
 * Copyright 2021-2023 by XGBoost contributors.
 */
#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#include <xgboost/context.h>                      // for Context
#include <xgboost/linalg.h>                       // for Constant, Vector
#include <xgboost/logging.h>                      // for CHECK
#include <xgboost/tree_model.h>                   // for RegTree

#include <vector>                                 // for vector

#include "../../../src/tree/hist/expand_entry.h"  // for CPUExpandEntry, MultiExpandEntry

namespace xgboost::tree {
/**
 * @brief Record a root split on feature 0 at @p split_value in `candidates->front()`.
 *
 * Shared by `GetSplit` and `GetMultiSplitForTest` so both record the candidate split the same
 * way. The caller must guarantee that @p candidates is non-null and non-empty.
 *
 * @tparam ExpandEntryT Expand-entry type exposing a `split` member with `split_value` and
 *                      `sindex` fields (e.g. CPUExpandEntry, MultiExpandEntry).
 * @param split_value   Threshold of the split.
 * @param candidates    Expand entries; only the first one is modified.
 */
template <typename ExpandEntryT>
void SetRootSplitForTest(float split_value, std::vector<ExpandEntryT> *candidates) {
  auto &split = candidates->front().split;
  split.split_value = split_value;
  // Feature index 0, matching `split_index=0` passed to `RegTree::ExpandNode` by the callers.
  split.sindex = 0;
  // Set the top bit of the split index. This mirrors `default_left=true` in the callers'
  // `ExpandNode` call (NOTE(review): assumed to be the split entry's default-left flag —
  // confirm against its definition).
  split.sindex |= (1U << 31);
}

/**
 * @brief Split the root of a single-target @p tree on feature 0 at @p split_value, and record
 *        the same split in `candidates->front()`.
 *
 * All weights and statistics passed to `ExpandNode` are zero; only the split itself is set.
 *
 * @param tree        Single-target tree whose root is expanded (checked via CHECK).
 * @param split_value Threshold used for both the tree node and the candidate entry.
 * @param candidates  Must be non-null and non-empty; only the first entry is modified.
 */
inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
  CHECK(tree);
  CHECK(!tree->IsMultiTarget());
  // Validate before touching the tree: `front()` on an empty vector is undefined behavior, and
  // checking first avoids leaving the tree expanded when the candidate cannot be updated.
  CHECK(candidates && !candidates->empty());
  tree->ExpandNode(
      /*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
      /*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
      /*left_sum=*/0.0f,
      /*right_sum=*/0.0f);
  SetRootSplitForTest(split_value, candidates);
}

/**
 * @brief Multi-target counterpart of `GetSplit`.
 *
 * Splits the root of @p tree on feature 0 at @p split_value using zero-valued weight vectors
 * (one element per target), then records the same split in `candidates->front()`.
 *
 * @param tree        Multi-target tree whose root is expanded (checked via CHECK).
 * @param split_value Threshold used for both the tree node and the candidate entry.
 * @param candidates  Must be non-null and non-empty; only the first entry is modified.
 */
inline void GetMultiSplitForTest(RegTree *tree, float split_value,
                                 std::vector<MultiExpandEntry> *candidates) {
  CHECK(tree);
  CHECK(tree->IsMultiTarget());
  // Validate before touching the tree, for the same reasons as in `GetSplit`.
  CHECK(candidates && !candidates->empty());
  auto n_targets = tree->NumTargets();
  Context ctx;
  // Zero-filled weight vectors with `n_targets` elements each.
  linalg::Vector<float> base_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
  linalg::Vector<float> left_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
  linalg::Vector<float> right_weight{linalg::Constant(&ctx, 0.0f, n_targets)};

  tree->ExpandNode(/*nidx=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
                   /*default_left=*/true, base_weight.HostView(), left_weight.HostView(),
                   right_weight.HostView());
  SetRootSplitForTest(split_value, candidates);
}
}  // namespace xgboost::tree
#endif  // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_