File: test_cache.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (120 lines) | stat: -rw-r--r-- 3,526 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/**
 * Copyright 2023 by XGBoost contributors
 */
#include <gtest/gtest.h>
#include <xgboost/cache.h>
#include <xgboost/data.h>  // for DMatrix

#include <cstddef>         // for size_t
#include <cstdint>         // for uint32_t
#include <thread>          // for thread

#include "helpers.h"       // for RandomDataGenerator

namespace xgboost {
namespace {
/**
 * @brief Minimal value type stored in DMatrixCache by the tests in this file.
 *
 * It only remembers the number it was constructed with, which lets a test verify that
 * the entry handed back by the cache was built from the expected argument.
 */
struct CacheForTest {
  /** @brief The constructor argument; immutable after construction. */
  std::size_t const i;

  /** @param k Value to remember in `i`. */
  explicit CacheForTest(std::size_t k) : i(k) {}
};
}  // namespace

// Single-threaded checks of DMatrixCache:
//  - an entry whose DMatrix has been destroyed is neither counted nor retrievable,
//  - a live entry can be looked up and carries the value it was created with,
//  - the container never grows beyond kCacheSize while more matrices than that are cached,
//    and sufficiently old entries can no longer be looked up.
TEST(DMatrixCache, Basic) {
  std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 4;
  DMatrixCache<CacheForTest> cache{kCacheSize};

  // The DMatrix is created inside a lambda so that its owning pointer is released as soon
  // as the lambda returns. The raw pointer handed back is therefore dangling; it is only
  // passed to Entry() below as a lookup argument to test how the cache handles expired
  // entries.
  auto add_cache = [&]() {
    auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
    cache.CacheItem(p_fmat, 3);
    DMatrix* m = p_fmat.get();
    return m;
  };
  auto m = add_cache();
  ASSERT_EQ(cache.Container().size(), 0u);
  ASSERT_THROW(cache.Entry(m), dmlc::Error);

  auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();

  // Both the value returned on insertion and the one found by lookup must hold the
  // argument given to CacheItem.
  auto item = cache.CacheItem(p_fmat, 1);
  ASSERT_EQ(item->i, 1u);
  ASSERT_EQ(cache.Entry(p_fmat.get())->i, 1u);

  // Insert twice as many matrices as the cache can hold, keeping all of them alive in
  // `items` so that entries can only disappear through the size limit, not through expiry.
  std::vector<std::shared_ptr<DMatrix>> items;
  items.reserve(kCacheSize * 2);
  for (std::size_t i = 0; i < kCacheSize * 2; ++i) {
    items.emplace_back(RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix());
    cache.CacheItem(items.back(), i);
    // The newest entry is always retrievable and the size bound always holds.
    ASSERT_EQ(cache.Entry(items.back().get())->i, i);
    ASSERT_LE(cache.Container().size(), kCacheSize);
    if (i > kCacheSize) {
      // The matrix inserted kCacheSize + 1 iterations ago must have been dropped by now.
      auto k = i - kCacheSize - 1;
      ASSERT_THROW(cache.Entry(items[k].get()), dmlc::Error);
    }
  }
}

// Calls DMatrixCache::CacheItem from many threads at once.
//
// Phases 1 and 2 use a small cache (kCacheSize): every thread caches the shared matrix
// and then a matrix of its own, and the value returned for its own matrix must carry the
// thread's index. Phase 2 repeats phase 1 with the threads launched in reverse index
// order. Phase 3 uses a cache with room for `n` entries and has every thread cache only
// the shared matrix; each thread is still expected to get back a value built from its own
// index.
TEST(DMatrixCache, MultiThread) {
  std::size_t constexpr kRows = 2, kCols = 1, kCacheSize = 3;
  auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();

#if defined(__linux__)
  auto const n = std::thread::hardware_concurrency() * 128;
#else
  auto const n = std::thread::hardware_concurrency();
#endif
  // hardware_concurrency() is allowed to return 0 when the value cannot be determined;
  // the test is meaningless without at least one worker, so fail early in that case.
  ASSERT_NE(n, 0u);
  std::vector<std::shared_ptr<CacheForTest>> results(n);

  // Waits for every thread in the list and then empties it so it can be refilled.
  auto join_all = [](std::vector<std::thread>* p_tasks) {
    for (auto& t : *p_tasks) {
      t.join();
    }
    p_tasks->clear();
  };

  {
    DMatrixCache<CacheForTest> cache{kCacheSize};

    // Work done by thread `i`. Each thread writes to its own slot of `results`, so the
    // vector itself needs no locking.
    auto work = [&](std::uint32_t i) {
      cache.CacheItem(p_fmat, i);

      auto p_fmat_local = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
      results[i] = cache.CacheItem(p_fmat_local, i);
    };

    std::vector<std::thread> tasks;
    tasks.reserve(n);

    // Phase 1: launch in ascending index order.
    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
      tasks.emplace_back(work, tidx);
    }
    join_all(&tasks);
    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
      ASSERT_EQ(results[tidx]->i, tidx);
    }

    // Phase 2: launch in descending index order. The counter runs from n down to 1 so
    // that the loop stays unsigned; the thread index is `remaining - 1`.
    for (std::uint32_t remaining = n; remaining > 0; --remaining) {
      tasks.emplace_back(work, remaining - 1);
    }
    join_all(&tasks);
    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
      ASSERT_EQ(results[tidx]->i, tidx);
    }
  }

  {
    // Phase 3: all threads cache the same matrix in a cache sized for `n` entries.
    DMatrixCache<CacheForTest> cache{n};
    std::vector<std::thread> tasks;
    tasks.reserve(n);
    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
      tasks.emplace_back([&, tidx]() { results[tidx] = cache.CacheItem(p_fmat, tidx); });
    }
    join_all(&tasks);
    for (std::uint32_t tidx = 0; tidx < n; ++tidx) {
      ASSERT_EQ(results[tidx]->i, tidx);
    }
  }
}
}  // namespace xgboost