File: test_extmem_quantile_dmatrix.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (61 lines) | stat: -rw-r--r-- 2,343 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
/**
 * Copyright 2024, XGBoost Contributors
 */
#include "test_extmem_quantile_dmatrix.h"  // for TestExtMemQdmBasic

#include <gtest/gtest.h>
#include <xgboost/data.h>  // for BatchParam

#include <algorithm>  // for equal
#include <string>     // for string
#include <vector>     // for vector

#include "../../../src/common/column_matrix.h"  // for ColumnMatrix
#include "../../../src/data/gradient_index.h"   // for GHistIndexMatrix
#include "../../../src/tree/param.h"            // for TrainParam

namespace xgboost::data {
namespace {
/**
 * @brief Parameterized fixture for the CPU external-memory quantile DMatrix test.
 *
 * The test parameter is a sparsity value that is forwarded to TestExtMemQdmBasic.
 */
class ExtMemQuantileDMatrixCpu : public ::testing::TestWithParam<float> {
 public:
  /**
   * @brief Run the basic external-memory QDM check on GHistIndexMatrix pages.
   *
   * @param sparsity Sparsity forwarded to TestExtMemQdmBasic.
   */
  void Run(float sparsity) {
    // Asserts that two GHistIndexMatrix pages are identical: same histogram cuts, same
    // row-wise bin buffer, and byte-identical serialized column matrices. The Context
    // argument is part of the comparator signature expected by TestExtMemQdmBasic and is
    // not needed here.
    auto equal = [](Context const*, GHistIndexMatrix const& orig, GHistIndexMatrix const& sparse) {
      // Check the cuts. Bind by const reference: plain `auto` would copy the whole cuts
      // object whenever `Cuts()` returns a reference.
      auto const& orig_cuts = orig.Cuts();
      auto const& sparse_cuts = sparse.Cuts();
      ASSERT_EQ(orig_cuts.Values(), sparse_cuts.Values());
      ASSERT_EQ(orig_cuts.MinValues(), sparse_cuts.MinValues());
      ASSERT_EQ(orig_cuts.Ptrs(), sparse_cuts.Ptrs());

      // Check the CSR bin buffer. Sizes are compared first so the element-wise comparison
      // below never reads past the end of `sparse.data`.
      ASSERT_EQ(orig.data.size(), sparse.data.size());
      auto orig_ptr = orig.data.data();
      auto sparse_ptr = sparse.data.data();
      bool const same_bins = std::equal(orig_ptr, orig_ptr + orig.data.size(), sparse_ptr);
      ASSERT_TRUE(same_bins);

      // Check the column matrix by serializing both and comparing the raw bytes.
      common::ColumnMatrix const& orig_columns = orig.Transpose();
      common::ColumnMatrix const& sparse_columns = sparse.Transpose();

      std::string str_orig;
      std::string str_sparse;
      common::AlignedMemWriteStream fo_orig{&str_orig};
      common::AlignedMemWriteStream fo_sparse{&str_sparse};
      auto n_bytes_orig = orig_columns.Write(&fo_orig);
      auto n_bytes_sparse = sparse_columns.Write(&fo_sparse);
      ASSERT_EQ(n_bytes_orig, n_bytes_sparse);
      ASSERT_EQ(str_orig, str_sparse);
    };

    Context ctx;
    // NOTE(review): the meaning of the `false` flag is defined by TestExtMemQdmBasic (not
    // visible here) — confirm before relying on it. The final argument is a predicate
    // reporting whether a page is dense.
    TestExtMemQdmBasic<GHistIndexMatrix>(
        &ctx, false, sparsity, equal, [](GHistIndexMatrix const& page) { return page.IsDense(); });
  }
};
}  // anonymous namespace

// Runs the basic external-memory QDM check once for each instantiated sparsity value.
TEST_P(ExtMemQuantileDMatrixCpu, Basic) {
  float const sparsity = GetParam();
  this->Run(sparsity);
}

// Sparsity values exercised, in order: 0, TrainParam::DftSparseThreshold(), 0.4 and 0.8.
INSTANTIATE_TEST_SUITE_P(ExtMemQuantileDMatrix, ExtMemQuantileDMatrixCpu,
                         ::testing::Values(0.0f, tree::TrainParam::DftSparseThreshold(), 0.4f,
                                           0.8f));
}  // namespace xgboost::data