File: test_expand_entry.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (58 lines) | stat: -rw-r--r-- 2,177 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
/**
 * Copyright 2023-2024, XGBoost Contributors
 */
#include <gtest/gtest.h>
#include <xgboost/json.h>        // for Json
#include <xgboost/tree_model.h>  // for RegTree

#include "../../../../src/common/categorical.h"  // for CatBitField
#include "../../../../src/tree/hist/expand_entry.h"

namespace xgboost::tree {
TEST(ExpandEntry, IO) {
  CPUExpandEntry entry{RegTree::kRoot, 0};
  entry.split.Update(1.0, 1, /*new_split_value=*/0.3, true, true, GradStats{1.0, 1.0},
                     GradStats{2.0, 2.0});
  bst_bin_t n_bins_feature = 256;
  auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
  entry.split.cat_bits = decltype(entry.split.cat_bits)(n, 0);
  common::CatBitField cat_bits{entry.split.cat_bits};
  cat_bits.Set(n_bins_feature / 2);

  Json je{Object{}};
  entry.Save(&je);

  CPUExpandEntry loaded;
  loaded.Load(je);

  ASSERT_EQ(loaded.split.is_cat, entry.split.is_cat);
  ASSERT_EQ(loaded.split.cat_bits, entry.split.cat_bits);
  ASSERT_EQ(loaded.split.left_sum.GetGrad(), entry.split.left_sum.GetGrad());
  ASSERT_EQ(loaded.split.right_sum.GetHess(), entry.split.right_sum.GetHess());
}

TEST(ExpandEntry, IOMulti) {
  MultiExpandEntry entry{RegTree::kRoot, 0};
  auto left_sum = std::vector<GradientPairPrecise>{{1.0, 1.0}, {1.0, 1.0}};
  auto right_sum = std::vector<GradientPairPrecise>{{2.0, 2.0}, {2.0, 2.0}};
  entry.split.Update(1.0, 1, /*new_split_value=*/0.3, true, true,
                     linalg::MakeVec(left_sum.data(), left_sum.size()),
                     linalg::MakeVec(right_sum.data(), right_sum.size()));
  bst_bin_t n_bins_feature = 256;
  auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
  entry.split.cat_bits = decltype(entry.split.cat_bits)(n, 0);
  common::CatBitField cat_bits{entry.split.cat_bits};
  cat_bits.Set(n_bins_feature / 2);

  Json je{Object{}};
  entry.Save(&je);

  MultiExpandEntry loaded;
  loaded.Load(je);

  ASSERT_EQ(loaded.split.is_cat, entry.split.is_cat);
  ASSERT_EQ(loaded.split.cat_bits, entry.split.cat_bits);
  ASSERT_EQ(loaded.split.left_sum, entry.split.left_sum);
  ASSERT_EQ(loaded.split.right_sum, entry.split.right_sum);
}
}  // namespace xgboost::tree