File: test_helpers.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (96 lines) | stat: -rw-r--r-- 3,590 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#include <gtest/gtest.h>
#include <algorithm>

#include "helpers.h"
#include "../../src/data/array_interface.h"
namespace xgboost {

// Generates a DMatrix, a CSR triple and a dense buffer from RandomDataGenerator
// instances constructed with identical arguments, then checks that all three
// describe the same values, column indices and row pointers.
TEST(RandomDataGenerator, DMatrix) {
  size_t constexpr kRows { 16 }, kCols { 32 };
  float constexpr kSparsity { 0.4f };
  auto p_dmatrix = RandomDataGenerator{kRows, kCols, kSparsity}.GenerateDMatrix();

  HostDeviceVector<float> csr_value;
  HostDeviceVector<std::size_t> csr_rptr;
  HostDeviceVector<bst_feature_t> csr_cidx;
  RandomDataGenerator{kRows, kCols, kSparsity}.GenerateCSR(&csr_value, &csr_rptr, &csr_cidx);

  HostDeviceVector<float> dense_data;
  RandomDataGenerator{kRows, kCols, kSparsity}.GenerateDense(&dense_data);

  // Compact the non-NaN entries of the dense buffer to its front so they can be
  // compared one-to-one with the sparse values.  std::remove_if is the in-place
  // form of this operation; std::copy_if requires that its source and
  // destination ranges do not overlap, so it must not be used on a single buffer.
  auto& h_dense = dense_data.HostVector();
  auto const dense_end =
      std::remove_if(h_dense.begin(), h_dense.end(), [](float v) { return std::isnan(v); });
  auto const n_valid = static_cast<std::size_t>(dense_end - h_dense.begin());

  ASSERT_EQ(p_dmatrix->Info().num_row_, kRows);
  ASSERT_EQ(p_dmatrix->Info().num_col_, kCols);

  auto const& h_csr_value = csr_value.ConstHostVector();
  auto const& h_csr_rptr = csr_rptr.ConstHostVector();
  auto const& h_csr_cidx = csr_cidx.ConstHostVector();

  for (auto const& page : p_dmatrix->GetBatches<SparsePage>()) {
    auto const& h_data = page.data.ConstHostVector();
    auto const& h_offset = page.offset.ConstHostVector();

    std::size_t const n_elements = page.data.Size();
    ASSERT_EQ(n_elements, n_valid);
    ASSERT_EQ(n_elements, csr_value.Size());
    // Guard the indexed reads of the column indices below.
    ASSERT_EQ(n_elements, csr_cidx.Size());

    for (std::size_t i = 0; i < n_elements; ++i) {
      ASSERT_EQ(h_dense[i], h_csr_value[i]);
      ASSERT_EQ(h_dense[i], h_data[i].fvalue);
      ASSERT_EQ(h_data[i].index, h_csr_cidx[i]);
    }

    ASSERT_EQ(page.offset.Size(), csr_rptr.Size());
    // Compare every row pointer entry.  Bounding this loop by the number of rows
    // would skip any trailing entry (a CSR row pointer conventionally holds one
    // more entry than there are rows).
    for (std::size_t i = 0; i < h_offset.size(); ++i) {
      ASSERT_EQ(h_offset[i], h_csr_rptr[i]);
    }
  }
}

// Checks the output of RandomDataGenerator::GenerateArrayInterfaceBatch: the
// requested number of per-batch array-interface strings is returned, each one
// validates and has kCols columns, the batch row counts add up to kRows, and
// the accompanying whole-array interface reports a (kRows, kCols) shape.
TEST(RandomDataGenerator, GenerateArrayInterfaceBatch) {
  size_t constexpr kRows { 937 }, kCols { 100 }, kBatches { 13 };
  float constexpr kSparsity { 0.4f };

  HostDeviceVector<float> storage;
  std::string array;
  std::vector<std::string> batches;
  std::tie(batches, array) =
      RandomDataGenerator{kRows, kCols, kSparsity}.GenerateArrayInterfaceBatch(
          &storage, kBatches);
  ASSERT_EQ(batches.size(), kBatches);

  // The JSON shape entries are read as signed integers, so compare and
  // accumulate in a signed type instead of mixing them with size_t.
  auto const expected_rows = static_cast<std::int64_t>(kRows);
  auto const expected_cols = static_cast<std::int64_t>(kCols);

  std::int64_t rows = 0;
  for (auto const &interface_str : batches) {
    Json j_interface =
        Json::Load({interface_str.c_str(), interface_str.size()});
    ArrayInterfaceHandler::Validate(get<Object const>(j_interface));
    ASSERT_EQ(get<Integer>(j_interface["shape"][1]), expected_cols);
    rows += get<Integer>(j_interface["shape"][0]);
  }
  ASSERT_EQ(rows, expected_rows);

  auto j_array = Json::Load({array.c_str(), array.size()});
  ASSERT_EQ(get<Integer>(j_array["shape"][0]), expected_rows);
  ASSERT_EQ(get<Integer>(j_array["shape"][1]), expected_cols);
}

// Builds one DMatrix through GenerateSparsePageDMatrix (split into kBatches
// batches, with its cache prefix inside a temporary directory) and one through
// GenerateDMatrix using the same shape and sparsity, then checks that the
// concatenated pages of the former equal the page(s) of the latter.
TEST(RandomDataGenerator, SparseDMatrix) {
  bst_idx_t constexpr kCols{100}, kBatches{13};
  bst_idx_t n_samples{kBatches * 128};
  dmlc::TemporaryDirectory tmpdir;
  auto prefix = tmpdir.path + "/cache";
  auto p_ext_fmat =
      RandomDataGenerator{n_samples, kCols, 0.0}.Batches(kBatches).GenerateSparsePageDMatrix(prefix,
                                                                                             true);

  auto p_fmat = RandomDataGenerator{n_samples, kCols, 0.0}.GenerateDMatrix(true);

  SparsePage concat;
  // Same unsigned type as kBatches, so the comparison below does not mix
  // signed and unsigned operands.
  bst_idx_t n_batches{0};
  for (auto const& page : p_ext_fmat->GetBatches<SparsePage>()) {
    concat.Push(page);
    ++n_batches;
  }
  ASSERT_EQ(n_batches, kBatches);
  ASSERT_EQ(concat.Size(), n_samples);

  bst_idx_t n_compared{0};
  for (auto const& page : p_fmat->GetBatches<SparsePage>()) {
    ASSERT_EQ(page.data.ConstHostVector(), concat.data.ConstHostVector());
    ASSERT_EQ(page.offset.ConstHostVector(), concat.offset.ConstHostVector());
    ++n_compared;
  }
  // Without this, the comparison above would pass vacuously if no page were
  // produced.
  ASSERT_GT(n_compared, 0u);
}
}  // namespace xgboost