File: test_gpu_approx.cu

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (65 lines) | stat: -rw-r--r-- 2,308 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
/**
 * Copyright 2024, XGBoost contributors
 */
#include <gtest/gtest.h>
#include <xgboost/json.h>          // for Json
#include <xgboost/task.h>          // for ObjInfo
#include <xgboost/tree_model.h>    // for RegTree
#include <xgboost/tree_updater.h>  // for TreeUpdater

#include "../../../src/tree/param.h"    // for TrainParam
#include "../collective/test_worker.h"  // for BaseMGPUTest
#include "../helpers.h"                 // for GenerateRandomGradients

namespace xgboost::tree {
namespace {
/**
 * @brief Grow a single tree with the "grow_gpu_approx" updater.
 *
 * The updater and the training parameters are both configured with an empty
 * argument list (no overrides are passed). Gradients come from
 * GenerateRandomGradients, one per row of the input matrix.
 *
 * @param ctx  Context handed to the updater; its device is also used for the
 *             gradient matrix.
 * @param dmat Training data passed to the updater.
 *
 * @return The tree filled in by the updater's Update call.
 */
RegTree GetApproxTree(Context const* ctx, DMatrix* dmat) {
  ObjInfo task{ObjInfo::kRegression};
  std::unique_ptr<TreeUpdater> updater{TreeUpdater::Create("grow_gpu_approx", ctx, &task)};
  updater->Configure(Args{});

  TrainParam train_param;
  train_param.UpdateAllowUnknown(Args{});

  // One gradient pair per row, stored on the context's device.
  auto const n_samples = dmat->Info().num_row_;
  linalg::Matrix<GradientPair> gradients({n_samples}, ctx->Device());
  gradients.Data()->Copy(GenerateRandomGradients(n_samples));

  // A single position vector, matching the single output tree.
  std::vector<HostDeviceVector<bst_node_t>> node_positions(1);
  common::Span<HostDeviceVector<bst_node_t>> positions_view{node_positions};

  RegTree result;
  updater->Update(&train_param, &gradients, dmat, positions_view, {&result});
  return result;
}

/**
 * @brief Train on this worker's column slice and compare against a reference tree.
 *
 * Generates a rows x cols random matrix, slices it by column using the
 * collective world size and rank, grows a tree on the slice, and asserts that
 * its saved JSON model equals that of expected_tree.
 *
 * @param rows          Number of rows of the generated data.
 * @param cols          Number of columns of the generated data (before slicing).
 * @param expected_tree Reference tree the result is compared with.
 */
void VerifyApproxColumnSplit(bst_idx_t rows, bst_feature_t cols, RegTree const& expected_tree) {
  auto ctx = MakeCUDACtx(DistGpuIdx());

  // Full matrix first, then keep only the columns belonging to this worker.
  auto full_data = RandomDataGenerator{rows, cols, 0}.GenerateDMatrix(true);
  auto const n_workers = collective::GetWorldSize();
  auto const worker_rank = collective::GetRank();
  std::unique_ptr<DMatrix> col_slice{full_data->SliceCol(n_workers, worker_rank)};

  RegTree tree = GetApproxTree(&ctx, col_slice.get());

  // Serialize a tree into a fresh JSON object for comparison.
  auto dump_model = [](RegTree const& t) {
    Json out{Object{}};
    t.SaveModel(&out);
    return out;
  };

  Json json = dump_model(tree);
  Json expected_json = dump_model(expected_tree);
  ASSERT_EQ(json, expected_json);
}
}  // anonymous namespace

// Test fixture for the GPU approx updater tests. It adds nothing of its own and
// only inherits from collective::BaseMGPUTest (whose DoTest is used by the tests).
class MGPUApproxTest : public collective::BaseMGPUTest {};

// Builds a reference tree on the full data with device 0, then checks through
// DoTest that workers training on column slices produce the same tree. DoTest is
// invoked twice: first with `true`, then with `false` as its second argument.
TEST_F(MGPUApproxTest, GPUApproxColumnSplit) {
  constexpr int kNumRows = 32;
  constexpr int kNumCols = 16;

  // Reference tree grown on the full (unsliced) matrix.
  auto ctx = MakeCUDACtx(0);
  auto full_data = RandomDataGenerator{kNumRows, kNumCols, 0}.GenerateDMatrix(true);
  RegTree expected_tree = GetApproxTree(&ctx, full_data.get());

  auto verify = [&] { VerifyApproxColumnSplit(kNumRows, kNumCols, expected_tree); };

  // Same order as two explicit calls: `true` first, then `false`.
  bool const second_args[] = {true, false};
  for (bool flag : second_args) {
    this->DoTest(verify, flag);
  }
}
}  // namespace xgboost::tree