File: helpers.h

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (514 lines) | stat: -rw-r--r-- 17,730 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
/**
 * Copyright 2016-2024, XGBoost contributors
 */
#pragma once

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <xgboost/base.h>
#include <xgboost/context.h>
#include <xgboost/json.h>
#include <xgboost/learner.h>  // for LearnerModelParam
#include <xgboost/model.h>    // for Configurable

#include <algorithm>    // for generate, min, max, copy
#include <cmath>        // for abs, log
#include <cstddef>      // for size_t
#include <cstdint>  // std::int32_t
#include <cstdio>
#include <limits>       // for numeric_limits
#include <memory>
#include <random>       // for mt19937, uniform_int_distribution
#include <string>
#include <type_traits>  // for is_floating_point_v
#include <utility>      // for pair, move
#include <vector>

#if defined(__CUDACC__)
#include "../../src/collective/communicator-inl.h"  // for GetRank
#include "../../src/common/cuda_rt_utils.h"         // for AllVisibleGPUs
#endif  // defined(__CUDACC__)

#include "filesystem.h"  // dmlc::TemporaryDirectory
#include "xgboost/linalg.h"

// Test-name helper: when this header is compiled with __CUDACC__ defined, the test name
// is prefixed with `GPU` by token pasting; otherwise the name is used unchanged.
#if defined(__CUDACC__)
#define DeclareUnifiedTest(name) GPU ## name
#else
#define DeclareUnifiedTest(name) name
#endif

// Device index used by tests. With __CUDACC__ defined: 0 when exactly one GPU is visible,
// otherwise the collective rank of this worker. Without it: -1.
// NOTE(review): the CUDA branch names `curt::` and `collective::` unqualified, so the macro
// presumably has to be expanded inside namespace xgboost — confirm at the call sites.
#if defined(__CUDACC__)
#define GPUIDX (curt::AllVisibleGPUs() == 1 ? 0 : collective::GetRank())
#else
#define GPUIDX (-1)
#endif

// Same idea as DeclareUnifiedTest, but the prefix is `MGPU` when __CUDACC__ is defined.
#if defined(__CUDACC__)
#define DeclareUnifiedDistributedTest(name) MGPU ## name
#else
#define DeclareUnifiedDistributedTest(name) name
#endif

// Forward declarations of xgboost types that appear in the helper signatures below.
namespace xgboost {
class ObjFunction;
class Metric;
struct LearnerModelParam;
class GradientBooster;
}  // namespace xgboost

/**
 * \brief Relative error of `l` with respect to `r`, computed as |1 - l / r|.
 *
 * Only floating point types are accepted (enforced at compile time). `r` is not checked
 * against zero.
 */
template <typename Float>
Float RelError(Float l, Float r) {
  static_assert(std::is_floating_point_v<Float>);
  auto const deviation = 1.0f - l / r;
  return std::abs(deviation);
}

// The helpers in this group are only declared here; their definitions live outside this
// header.

// NOTE(review): presumably reports whether `filename` exists on disk — confirm against the
// definition.
bool FileExists(const std::string& filename);

// NOTE(review): presumably writes a small fixed test data set to `filename` — the format is
// decided by the definition.
void CreateSimpleTestData(const std::string& filename);

// Create a libsvm format file with 3 entries per-row. `zero_based` specifies whether it's
// 0-based indexing.
void CreateBigTestData(const std::string& filename, size_t n_entries, bool zero_based = true);

// NOTE(review): presumably writes a CSV file with `rows` x `cols` values to `path` — confirm
// against the definition.
void CreateTestCSV(std::string const& path, size_t rows, size_t cols);

// Check an objective function. All vectors are taken by value.
// NOTE(review): presumably computes gradients from `preds`/`labels`/`weights` with `obj` and
// compares them with the expected `out_grad`/`out_hess` — confirm against the definition.
void CheckObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
                      std::vector<xgboost::bst_float> preds,
                      std::vector<xgboost::bst_float> labels,
                      std::vector<xgboost::bst_float> weights,
                      std::vector<xgboost::bst_float> out_grad,
                      std::vector<xgboost::bst_float> out_hess);

// Untyped implementation behind CheckConfigReload; takes the object as a raw
// xgboost::Configurable pointer and returns a JSON document.
xgboost::Json CheckConfigReloadImpl(xgboost::Configurable* const configurable,
                                    std::string name);

/**
 * \brief Typed front-end for CheckConfigReloadImpl.
 *
 * The object owned by `configurable` is converted to xgboost::Configurable* with
 * dynamic_cast and handed, together with `name`, to CheckConfigReloadImpl. The JSON
 * returned by the implementation is passed through unchanged.
 */
template <typename T>
xgboost::Json CheckConfigReload(std::unique_ptr<T> const& configurable,
                                std::string name = "") {
  auto* const as_configurable = dynamic_cast<xgboost::Configurable*>(configurable.get());
  return CheckConfigReloadImpl(as_configurable, name);
}

// Ranking variant of CheckObjFunction: same parameters plus `groups`.
// NOTE(review): presumably `groups` carries the query-group information used by ranking
// objectives — confirm against the definition.
void CheckRankingObjFunction(std::unique_ptr<xgboost::ObjFunction> const& obj,
                             std::vector<xgboost::bst_float> preds,
                             std::vector<xgboost::bst_float> labels,
                             std::vector<xgboost::bst_float> weights,
                             std::vector<xgboost::bst_uint> groups,
                             std::vector<xgboost::bst_float> out_grad,
                             std::vector<xgboost::bst_float> out_hess);

// Evaluate `metric` on `preds` against `labels`. `weights` and `groups` default to empty
// vectors and the data split mode defaults to row-wise.
xgboost::bst_float GetMetricEval(
  xgboost::Metric * metric,
  xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
  std::vector<xgboost::bst_float> labels,
  std::vector<xgboost::bst_float> weights = std::vector<xgboost::bst_float>(),
  std::vector<xgboost::bst_uint> groups = std::vector<xgboost::bst_uint>(),
  xgboost::DataSplitMode data_split_Mode = xgboost::DataSplitMode::kRow);

// Same as GetMetricEval, but the labels are a 2-dimensional tensor and the result is a
// double.
double GetMultiMetricEval(xgboost::Metric* metric,
                          xgboost::HostDeviceVector<xgboost::bst_float> const& preds,
                          xgboost::linalg::Tensor<float, 2> const& labels,
                          std::vector<xgboost::bst_float> weights = {},
                          std::vector<xgboost::bst_uint> groups = {},
                          xgboost::DataSplitMode data_split_Mode = xgboost::DataSplitMode::kRow);

namespace xgboost {

// Declared here, defined elsewhere.
// NOTE(review): presumably extracts the base score from a configuration JSON document —
// which keys are read is decided by the definition.
float GetBaseScore(Json const &config);

/*!
 * \brief Linear congruential generator.
 *
 * The distributions defined in std are not portable: given the same seed, they might
 * produce different outputs on different platforms or with different compilers.
 * SimpleLCG is implemented here to make sure all tests are reproducible.
 *
 * operator(), Min() and Max() are defined out of line.
 */
class SimpleLCG {
 private:
  using StateType = uint64_t;
  static StateType constexpr kDefaultInit = 3;
  static StateType constexpr kDefaultAlpha = 61;
  static StateType constexpr kMaxValue = (static_cast<StateType>(1) << 32) - 1;

  StateType state_;
  // Both constants are fixed for every instance; being const, they make the class
  // copy/move constructible but not assignable.
  StateType const alpha_{kDefaultAlpha};
  StateType const mod_{kMaxValue};

 public:
  using result_type = StateType;  // NOLINT

  /*! \brief Construct with the internal default state. */
  SimpleLCG() : SimpleLCG(kDefaultInit) {}
  SimpleLCG(SimpleLCG const& that) = default;
  SimpleLCG(SimpleLCG&& that) = default;

  /*!
   * \brief Initialize SimpleLCG.
   *
   * \param state  Initial state, can also be considered as seed. If set to
   *               zero, SimpleLCG will use internal default value.
   */
  explicit SimpleLCG(StateType state) : state_{state != 0 ? state : kDefaultInit} {}

  /*! \brief Replace the current state with `seed` reduced modulo mod_. */
  void Seed(StateType seed) { state_ = seed % mod_; }

  StateType operator()();
  StateType Min() const;
  StateType Max() const;

  static constexpr result_type min() { return 0; }          // NOLINT
  static constexpr result_type max() { return kMaxValue; }  // NOLINT
};

/*!
 * \brief Maps the output of a generator onto the range starting at `lower`.
 *
 * The generator is passed by pointer and must provide Min(), Max() and operator(). A
 * canonical value is scaled by (upper - lower) and shifted by lower; the result is never
 * smaller than lower.
 */
template <typename ResultT>
class SimpleRealUniformDistribution {
 private:
  ResultT const lower_;
  ResultT const upper_;

  /*! \brief Over-simplified version of std::generate_canonical. */
  template <size_t Bits, typename GeneratorT>
  ResultT GenerateCanonical(GeneratorT* rng) const {
    static_assert(std::is_floating_point_v<ResultT>, "Result type must be floating point.");
    // Number of distinct values the generator can return.
    long double const range = (static_cast<long double>(rng->Max())
                               - static_cast<long double>(rng->Min())) + 1.0L;
    auto const bits_per_draw = static_cast<size_t>(std::log(range) / std::log(2.0L));
    // Number of draws needed to collect at least `Bits` bits, never fewer than one.
    size_t n_draws = std::max<size_t>(1UL, (Bits + bits_per_draw - 1UL) / bits_per_draw);

    ResultT accum = 0;
    ResultT scale = 1;
    while (n_draws != 0) {
      accum += static_cast<ResultT>((*rng)() - rng->Min()) * scale;
      scale *= static_cast<ResultT>(range);
      --n_draws;
    }

    return accum / scale;
  }

 public:
  SimpleRealUniformDistribution(ResultT l, ResultT u) : lower_{l}, upper_{u} {}

  template <typename GeneratorT>
  ResultT operator()(GeneratorT* rng) const {
    constexpr size_t kBits = std::numeric_limits<ResultT>::digits;
    ResultT const unit = GenerateCanonical<kBits, GeneratorT>(rng);
    ResultT const scaled = (unit * (upper_ - lower_)) + lower_;
    // Correct floating point error: never return a value below the lower bound.
    return (scaled < lower_) ? lower_ : scaled;
  }
};

/**
 * \brief Describe `storage` as an array-interface JSON object with shape (rows, cols).
 *
 * The "data" entry holds the address of the storage (the device pointer when the device
 * can read the vector, the host pointer otherwise) followed by `false`. A null "stream"
 * entry is added only in the device case. "typestr" is '<' followed by the type character
 * of T and sizeof(T); "version" is 3.
 */
template <typename T>
Json GetArrayInterface(HostDeviceVector<T> const* storage, size_t rows, size_t cols) {
  Json array_interface{Object()};

  bool const on_device = storage->DeviceCanRead();
  T const* const address =
      on_device ? storage->ConstDevicePointer() : storage->ConstHostPointer();

  array_interface["data"] = std::vector<Json>(2);
  array_interface["data"][0] = Integer{reinterpret_cast<int64_t>(address)};
  array_interface["data"][1] = Boolean(false);
  if (on_device) {
    array_interface["stream"] = nullptr;
  }

  array_interface["shape"] = std::vector<Json>(2);
  array_interface["shape"][0] = rows;
  array_interface["shape"][1] = cols;

  char const type_char = linalg::detail::ArrayInterfaceHandler::TypeChar<T>();
  std::string typestr{"<"};
  typestr += type_char;
  typestr += std::to_string(sizeof(T));
  array_interface["typestr"] = String(typestr);
  array_interface["version"] = 3;
  return array_interface;
}

/**
 * \brief Generate in-memory random data without using DMatrix.
 *
 * The generator is configured through chainable setters (each one stores its argument and
 * returns `*this`) and then asked for data through one of the Generate* methods, which are
 * defined out of line.
 */
class RandomDataGenerator {
  bst_idx_t rows_;
  size_t cols_;
  float sparsity_;

  float lower_{0.0f};
  float upper_{1.0f};

  bst_target_t n_targets_{1};
  bst_target_t n_classes_{0};

  DeviceOrd device_{DeviceOrd::CPU()};
  std::size_t n_batches_{0};
  std::uint64_t seed_{0};
  // Must stay declared after seed_: the constructor initializes lcg_ from seed_ and
  // members are initialized in declaration order.
  SimpleLCG lcg_;

  bst_bin_t bins_{0};
  std::vector<FeatureType> ft_;
  bst_cat_t max_cat_{32};
  bool on_host_{false};
  std::shared_ptr<DMatrix> ref_{nullptr};
  std::int64_t min_cache_page_bytes_{0};
  std::int64_t max_num_device_pages_{1};

  Json ArrayInterfaceImpl(HostDeviceVector<float>* storage, size_t rows, size_t cols) const;

  void GenerateLabels(std::shared_ptr<DMatrix> p_fmat) const;

 public:
  /**
   * \brief Create a generator for a `rows` x `cols` data set with the given sparsity.
   *
   * Every other option starts from its in-class default.
   */
  RandomDataGenerator(bst_idx_t rows, size_t cols, float sparsity)
      : rows_{rows}, cols_{cols}, sparsity_{sparsity}, lcg_{seed_} {}

  /*! \brief Set lower_ (default 0.0f). */
  RandomDataGenerator& Lower(float v) {
    this->lower_ = v;
    return *this;
  }
  /*! \brief Set upper_ (default 1.0f). */
  RandomDataGenerator& Upper(float v) {
    this->upper_ = v;
    return *this;
  }
  /*! \brief Set the device (default CPU). */
  RandomDataGenerator& Device(DeviceOrd d) {
    this->device_ = d;
    return *this;
  }
  /*! \brief Set the number of batches (default 0). */
  RandomDataGenerator& Batches(std::size_t n_batches) {
    this->n_batches_ = n_batches;
    return *this;
  }
  /*! \brief Set the on-host flag (default false). */
  RandomDataGenerator& OnHost(bool on_host) {
    this->on_host_ = on_host;
    return *this;
  }
  /*! \brief Set the reference DMatrix (default null); the argument is moved in. */
  RandomDataGenerator& Ref(std::shared_ptr<DMatrix> ref) {
    this->ref_ = std::move(ref);
    return *this;
  }
  /*! \brief Set min_cache_page_bytes_ (default 0). */
  RandomDataGenerator& MinPageCacheBytes(std::int64_t min_cache_page_bytes) {
    this->min_cache_page_bytes_ = min_cache_page_bytes;
    return *this;
  }
  /*! \brief Set max_num_device_pages_ (default 1). */
  RandomDataGenerator& MaxNumDevicePages(std::int64_t max_num_device_pages) {
    this->max_num_device_pages_ = max_num_device_pages;
    return *this;
  }
  /*! \brief Store the seed and re-seed the internal generator with it. */
  RandomDataGenerator& Seed(uint64_t s) {
    this->seed_ = s;
    this->lcg_.Seed(this->seed_);
    return *this;
  }
  /*! \brief Set the number of bins (default 0). */
  RandomDataGenerator& Bins(bst_bin_t b) {
    this->bins_ = b;
    return *this;
  }
  /*!
   * \brief Set the per-feature types.
   *
   * The span must have exactly one entry per column (checked); its contents are copied.
   */
  RandomDataGenerator& Type(common::Span<FeatureType> ft) {
    CHECK_EQ(ft.size(), cols_);
    this->ft_.assign(ft.cbegin(), ft.cend());
    return *this;
  }
  /*! \brief Set max_cat_ (default 32). */
  RandomDataGenerator& MaxCategory(bst_cat_t cat) {
    this->max_cat_ = cat;
    return *this;
  }
  /*! \brief Set the number of targets (default 1). */
  RandomDataGenerator& Targets(bst_target_t n_targets) {
    this->n_targets_ = n_targets;
    return *this;
  }
  /*! \brief Set the number of classes (default 0). */
  RandomDataGenerator& Classes(bst_target_t n_classes) {
    this->n_classes_ = n_classes;
    return *this;
  }

  void GenerateDense(HostDeviceVector<float>* out) const;

  std::string GenerateArrayInterface(HostDeviceVector<float>* storage) const;

  /*!
   * \brief Generate batches of array interface stored in consecutive memory.
   *
   * \param storage The consecutive memory used to store the arrays.
   * \param batches Number of batches.
   *
   * \return A vector storing JSON string representation of interface for each batch, and
   *         a single JSON string representing the consecutive memory as a whole
   *         (combining all the batches).
   */
  std::pair<std::vector<std::string>, std::string> GenerateArrayInterfaceBatch(
      HostDeviceVector<float>* storage, size_t batches) const;

  std::string GenerateColumnarArrayInterface(std::vector<HostDeviceVector<float>>* data) const;

  void GenerateCSR(HostDeviceVector<float>* value, HostDeviceVector<std::size_t>* row_ptr,
                   HostDeviceVector<bst_feature_t>* columns) const;

  [[nodiscard]] std::shared_ptr<DMatrix> GenerateDMatrix(
      bool with_label = false, DataSplitMode data_split_mode = DataSplitMode::kRow) const;

  [[nodiscard]] std::shared_ptr<DMatrix> GenerateSparsePageDMatrix(std::string prefix,
                                                                   bool with_label) const;

  [[nodiscard]] std::shared_ptr<DMatrix> GenerateExtMemQuantileDMatrix(std::string prefix,
                                                                       bool with_label) const;

  std::shared_ptr<DMatrix> GenerateQuantileDMatrix(bool with_label);
};

// Generate an empty DMatrix, mostly for its meta info.
inline std::shared_ptr<DMatrix> EmptyDMatrix() {
  return RandomDataGenerator{0, 0, 0.0}.GenerateDMatrix();
}

/**
 * \brief Generate a single column of random categorical values.
 *
 * Each of the `n` entries is drawn from {0, ..., num_categories - 1} with a fixed seed.
 * Afterwards the leading entries are overwritten with 0, 1, 2, ... so that every category
 * is present. When the column is shorter than the number of categories, only as many
 * categories as fit are forced in, instead of writing past the end of the column.
 *
 * \param n              Number of entries in the column.
 * \param num_categories Number of distinct categories; expected to be at least 1, because
 *                       the upper bound of the distribution is `num_categories - 1`.
 *
 * \return The generated column, with the category codes stored as float.
 */
inline std::vector<float> GenerateRandomCategoricalSingleColumn(int n, size_t num_categories) {
  std::vector<float> x(n);
  std::mt19937 rng(0);
  std::uniform_int_distribution<size_t> dist(0, num_categories - 1);
  std::generate(x.begin(), x.end(), [&]() { return static_cast<float>(dist(rng)); });
  // Make sure each category is present, without indexing beyond the end of the column.
  auto const n_forced = std::min(x.size(), num_categories);
  for (size_t i = 0; i < n_forced; i++) {
    x[i] = static_cast<decltype(x)::value_type>(i);
  }
  return x;
}

// Declared here, defined elsewhere.
// NOTE(review): presumably wraps the flat vector `x` as a dense `num_rows` x `num_columns`
// DMatrix — the memory layout expected for `x` is decided by the definition.
std::shared_ptr<DMatrix> GetDMatrixFromData(const std::vector<float>& x, std::size_t num_rows,
                                            bst_feature_t num_columns);

// Declared here, defined elsewhere.
// NOTE(review): presumably creates the gradient booster called `name`, configures it with
// `kwargs` and trains it on generated `kRows` x `kCols` data — confirm against the
// definition. The two pointer arguments are not owned by this function's signature.
std::unique_ptr<GradientBooster> CreateTrainedGBM(std::string name, Args kwargs, size_t kRows,
                                                  size_t kCols,
                                                  LearnerModelParam const* learner_model_param,
                                                  Context const* generic_param);

/**
 * \brief Make a context for the given device ordinal.
 *
 * A default-constructed Context is returned when `device` equals
 * DeviceOrd::CPUOrdinal(); any other ordinal is forwarded to Context::MakeCUDA.
 */
inline Context MakeCUDACtx(std::int32_t device) {
  if (device != DeviceOrd::CPUOrdinal()) {
    return Context{}.MakeCUDA(device);
  }
  return Context{};
}

/**
 * \brief Generate `n_rows` gradient pairs from a default-seeded SimpleLCG.
 *
 * Gradient and hessian of each pair are drawn, in that order, from a
 * SimpleRealUniformDistribution constructed with (`lower`, `upper`). The generator always
 * starts from its default state, so the output only depends on the arguments.
 */
inline HostDeviceVector<GradientPair> GenerateRandomGradients(const size_t n_rows,
                                                              float lower = 0.0f,
                                                              float upper = 1.0f) {
  xgboost::SimpleLCG rng;
  xgboost::SimpleRealUniformDistribution<bst_float> uniform(lower, upper);

  std::vector<GradientPair> h_gpair(n_rows);
  for (size_t i = 0; i < n_rows; ++i) {
    // Keep the two draws as separate statements: the gradient must be drawn before the
    // hessian for the sequence to be reproducible.
    bst_float const grad = uniform(&rng);
    bst_float const hess = uniform(&rng);
    h_gpair[i] = GradientPair(grad, hess);
  }
  return HostDeviceVector<GradientPair>(h_gpair);
}

/**
 * \brief Generate a (`n_rows`, `n_targets`) matrix of random gradient pairs.
 *
 * The matrix is created on the device of `ctx`; its storage is filled by copying
 * `n_rows * n_targets` pairs produced by the flat overload of GenerateRandomGradients.
 */
inline linalg::Matrix<GradientPair> GenerateRandomGradients(Context const* ctx, bst_idx_t n_rows,
                                                            bst_target_t n_targets,
                                                            float lower = 0.0f,
                                                            float upper = 1.0f) {
  linalg::Matrix<GradientPair> gpair({n_rows, static_cast<bst_idx_t>(n_targets)}, ctx->Device());
  auto flat = GenerateRandomGradients(n_rows * n_targets, lower, upper);
  gpair.Data()->Copy(flat);
  return gpair;
}

// Opaque DMatrix handle, declared as a plain void pointer.
// NOTE(review): presumably mirrors the typedef of the same name in the C API — confirm.
typedef void *DMatrixHandle;  // NOLINT(*);

/**
 * \brief Abstract base class for the batch iterators used in tests.
 *
 * Derived classes implement Next(). Constructors and destructor are defined out of line.
 */
class ArrayIterForTest {
 protected:
  HostDeviceVector<float> data_;
  size_t iter_{0};
  DMatrixHandle proxy_;
  std::unique_ptr<RandomDataGenerator> rng_;

  std::vector<std::string> batches_;
  std::string interface_;
  bst_idx_t rows_;
  size_t cols_;
  size_t n_batches_;

 public:
  // Sizes used as default constructor arguments by the derived iterators.
  static constexpr bst_idx_t Rows() { return 1024; }
  static constexpr size_t Batches() { return 100; }
  static constexpr size_t Cols() { return 13; }

  /** \brief Copy of the stored interface string. */
  [[nodiscard]] std::string AsArray() const { return interface_; }

  virtual int Next() = 0;
  /** \brief Rewind the batch counter to zero. */
  virtual void Reset() { iter_ = 0; }
  /** \brief Current value of the batch counter. */
  [[nodiscard]] std::size_t Iter() const { return iter_; }
  /** \brief The stored proxy handle. */
  DMatrixHandle Proxy() { return proxy_; }

  explicit ArrayIterForTest(float sparsity, bst_idx_t rows, size_t cols, size_t batches);
  /**
   * \brief Create iterator with user provided data.
   */
  explicit ArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
                            std::size_t n_samples, bst_feature_t n_features, std::size_t n_batches);
  virtual ~ArrayIterForTest();
};

// ArrayIterForTest implementation whose constructor and Next() are defined out of line.
// NOTE(review): judging by the name this variant serves CUDA (device) arrays — confirm
// against the definition.
class CudaArrayIterForTest : public ArrayIterForTest {
 public:
  // Row, column and batch counts default to the sizes published by the base class.
  // NOTE(review): `rows` is size_t here but bst_idx_t in the other iterator classes.
  explicit CudaArrayIterForTest(float sparsity, size_t rows = Rows(), size_t cols = Cols(),
                                size_t batches = Batches());
  int Next() override;
  ~CudaArrayIterForTest() override = default;
};

/**
 * \brief ArrayIterForTest implementation whose Next() is defined out of line.
 *
 * Can be constructed either from a sparsity and sizes (defaulting to the sizes published
 * by the base class) or from user provided data, in which case all arguments are passed
 * straight to the base class.
 */
class NumpyArrayIterForTest : public ArrayIterForTest {
 public:
  explicit NumpyArrayIterForTest(float sparsity, bst_idx_t rows = Rows(), size_t cols = Cols(),
                                 size_t batches = Batches());
  explicit NumpyArrayIterForTest(Context const* ctx, HostDeviceVector<float> const& data,
                                 std::size_t n_samples, bst_feature_t n_features,
                                 std::size_t n_batches)
      : ArrayIterForTest(ctx, data, n_samples, n_features, n_batches) {}
  ~NumpyArrayIterForTest() override = default;

  int Next() override;
};

// Declared here, defined elsewhere. Results are written through the three output pointers.
// NOTE(review): presumably converts `dmat` into CSR form (values, row pointers, column
// indices) — confirm against the definition.
void DMatrixToCSR(DMatrix *dmat, std::vector<float> *p_data,
                  std::vector<size_t> *p_row_ptr,
                  std::vector<bst_feature_t> *p_cids);

// Opaque data-iterator handle, declared as a plain void pointer. The Reset/Next callbacks
// below cast it to ArrayIterForTest*.
typedef void *DataIterHandle;  // NOLINT(*)

/**
 * \brief Reset callback: treats `self` as an ArrayIterForTest and calls its Reset().
 */
inline void Reset(DataIterHandle self) {
  auto* const iter = static_cast<ArrayIterForTest*>(self);
  iter->Reset();
}

/**
 * \brief Next callback: treats `self` as an ArrayIterForTest and returns its Next().
 */
inline int Next(DataIterHandle self) {
  auto* const iter = static_cast<ArrayIterForTest*>(self);
  return iter->Next();
}

/**
 * @brief Create an array interface for host vector.
 *
 * The string produced by linalg::Make1dInterface is kept in a thread-local buffer that is
 * static to each instantiation of this template. The returned pointer refers to that
 * buffer, so it is overwritten by the next call with the same T on the same thread.
 */
template <typename T>
char const* Make1dInterfaceTest(T const* vec, std::size_t len) {
  static thread_local std::string buffer;
  buffer = linalg::Make1dInterface(vec, len);
  return buffer.data();
}

// RMMAllocator is only forward declared here, so the owning pointer carries its deleter as
// a plain function pointer instead of requiring the complete type.
class RMMAllocator;
using RMMAllocatorPtr = std::unique_ptr<RMMAllocator, void(*)(RMMAllocator*)>;
// Declared here, defined elsewhere; receives the program's command line arguments.
// NOTE(review): presumably sets up the RMM memory resource for the C++ tests and returns
// its owner — confirm against the definition.
RMMAllocatorPtr SetUpRMMResourceForCppTests(int argc, char** argv);

/*
 * \brief Make learner model param
 *
 * The base score is stored as a one-element, one-dimensional tensor on `device`. The two
 * remaining constructor arguments are fixed to 1 and MultiStrategy::kOneOutputPerTree.
 */
inline LearnerModelParam MakeMP(bst_feature_t n_features, float base_score, uint32_t n_groups,
                                DeviceOrd device = DeviceOrd::CPU()) {
  size_t dims[1]{1};
  return LearnerModelParam(n_features, linalg::Tensor<float, 1>{{base_score}, dims, device},
                           n_groups, 1, MultiStrategy::kOneOutputPerTree);
}

inline std::int32_t AllThreadsForTest() { return Context{}.Threads(); }

// The CUDA device with ordinal 0.
inline DeviceOrd FstCU() {
  auto const first_cuda = DeviceOrd::CUDA(0);
  return first_cuda;
}

// GPU device ordinal for distributed tests. Declared here, defined elsewhere.
std::int32_t DistGpuIdx();

// GoogleMock matcher for code that throws a dmlc::Error whose message contains `msg`.
inline auto GMockThrow(StringView msg) {
  auto substring_matcher = ::testing::HasSubstr(msg);
  return ::testing::ThrowsMessage<dmlc::Error>(substring_matcher);
}
}  // namespace xgboost