1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
|
/**
* Copyright 2021-2023 by XGBoost contributors.
*/
#ifndef XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#define XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
#include <xgboost/context.h> // for Context
#include <xgboost/linalg.h> // for Constant, Vector
#include <xgboost/logging.h> // for CHECK
#include <xgboost/tree_model.h> // for RegTree
#include <vector> // for vector
#include "../../../src/tree/hist/expand_entry.h" // for CPUExpandEntry, MultiExpandEntry
namespace xgboost::tree {
inline void GetSplit(RegTree *tree, float split_value, std::vector<CPUExpandEntry> *candidates) {
CHECK(!tree->IsMultiTarget());
tree->ExpandNode(
/*nid=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
/*default_left=*/true, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f,
/*left_sum=*/0.0f,
/*right_sum=*/0.0f);
candidates->front().split.split_value = split_value;
candidates->front().split.sindex = 0;
candidates->front().split.sindex |= (1U << 31);
}
inline void GetMultiSplitForTest(RegTree *tree, float split_value,
std::vector<MultiExpandEntry> *candidates) {
CHECK(tree->IsMultiTarget());
auto n_targets = tree->NumTargets();
Context ctx;
linalg::Vector<float> base_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
linalg::Vector<float> left_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
linalg::Vector<float> right_weight{linalg::Constant(&ctx, 0.0f, n_targets)};
tree->ExpandNode(/*nidx=*/RegTree::kRoot, /*split_index=*/0, /*split_value=*/split_value,
/*default_left=*/true, base_weight.HostView(), left_weight.HostView(),
right_weight.HostView());
candidates->front().split.split_value = split_value;
candidates->front().split.sindex = 0;
candidates->front().split.sindex |= (1U << 31);
}
} // namespace xgboost::tree
#endif // XGBOOST_TESTS_CPP_TREE_TEST_PARTITIONER_H_
|