File: test_gradient_index_page_raw_format.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (62 lines) | stat: -rw-r--r-- 2,441 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
/**
 * Copyright 2021-2024, XGBoost contributors
 */
#include <gtest/gtest.h>
#include <xgboost/context.h>  // for Context

#include <cstddef>  // for size_t
#include <memory>   // for unique_ptr

#include "../../../src/common/column_matrix.h"  // for common::ColumnMatrix
#include "../../../src/common/io.h"             // for MmapResource, AlignedResourceReadStream...
#include "../../../src/data/gradient_index.h"   // for GHistIndexMatrix
#include "../../../src/data/gradient_index_format.h"  // for GHistIndexRawFormat
#include "../helpers.h"                               // for RandomDataGenerator

namespace xgboost::data {
/**
 * Round-trip test for GHistIndexRawFormat.
 *
 * Builds gradient index pages from a randomly generated DMatrix, writes them to a temporary
 * file with the raw page format, reads one page back through a memory-mapped stream and
 * checks that the deserialized page matches the in-memory reference field by field.
 */
TEST(GHistIndexPageRawFormat, IO) {
  Context ctx;

  // NOTE(review): arguments are presumably rows, columns and sparsity -- confirm against
  // RandomDataGenerator in helpers.h.
  auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
  dmlc::TemporaryDirectory tmpdir;
  std::string path = tmpdir.path + "/ghistindex.page";
  auto batch = BatchParam{256, 0.5};

  // The format object is constructed from histogram cuts; take them from the first batch.
  common::HistogramCuts cuts;
  for (auto const &index : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
    cuts = index.Cuts();
    break;
  }
  auto format = std::make_unique<GHistIndexRawFormat>(std::move(cuts));

  // Serialize every batch and accumulate the byte count reported by Write(). The writer lives
  // in its own scope so that it is destroyed before the file is opened for reading.
  std::size_t bytes{0};
  {
    auto fo = std::make_unique<common::AlignedFileWriteStream>(StringView{path}, "wb");
    for (auto const &index : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
      bytes += format->Write(index, fo.get());
    }
  }

  // Deserialize a page from the first `bytes` bytes of the file.
  GHistIndexMatrix page;

  std::unique_ptr<common::AlignedResourceReadStream> fi{
      std::make_unique<common::PrivateMmapConstStream>(path, 0, bytes)};
  ASSERT_TRUE(format->Read(&page, fi.get()));

  // `expected` is the in-memory reference batch, `page` is what came back from disk.
  for (auto const &expected : m->GetBatches<GHistIndexMatrix>(&ctx, batch)) {
    // Histogram cuts.
    ASSERT_EQ(expected.cut.Ptrs(), page.cut.Ptrs());
    ASSERT_EQ(expected.cut.MinValues(), page.cut.MinValues());
    ASSERT_EQ(expected.cut.Values(), page.cut.Values());

    // Row layout.
    ASSERT_EQ(expected.base_rowid, page.base_rowid);
    ASSERT_EQ(expected.row_ptr.size(), page.row_ptr.size());
    ASSERT_TRUE(std::equal(expected.row_ptr.cbegin(), expected.row_ptr.cend(),
                           page.row_ptr.cbegin(), page.row_ptr.cend()));
    ASSERT_EQ(expected.IsDense(), page.IsDense());

    // Bin index and its offsets. The four-iterator std::equal also compares the range lengths,
    // so a truncated page fails the assertion instead of being read past its end.
    ASSERT_TRUE(std::equal(expected.index.begin(), expected.index.end(), page.index.begin(),
                           page.index.end()));
    ASSERT_EQ(expected.index.OffsetSize(), page.index.OffsetSize());
    ASSERT_TRUE(std::equal(expected.index.Offset(),
                           expected.index.Offset() + expected.index.OffsetSize(),
                           page.index.Offset(),
                           page.index.Offset() + page.index.OffsetSize()));

    // Column matrix: compare the reference against the deserialized page. The previous check
    // compared `Transpose().GetTypeSize()` of the reference with itself and could never fail.
    ASSERT_EQ(expected.Transpose().GetTypeSize(), page.Transpose().GetTypeSize());
  }
}
}  // namespace xgboost::data