File: test_sycl_quantile_hist_builder.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (55 lines) | stat: -rw-r--r-- 1,814 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
/**
 * Copyright 2020-2024 by XGBoost contributors
 */
#include <gtest/gtest.h>

#include <memory>  // for unique_ptr
#include <string>  // for stoi, to_string

#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wtautological-constant-compare"
#pragma GCC diagnostic ignored "-W#pragma-messages"
#include <xgboost/json.h>
#include <xgboost/task.h>
#include "../../../plugin/sycl/tree/updater_quantile_hist.h"       // for QuantileHistMaker
#pragma GCC diagnostic pop

namespace xgboost::sycl::tree {
/**
 * Smoke test: the SYCL histogram updater can be created through the
 * TreeUpdater registry and reports the name it was requested under.
 */
TEST(SyclQuantileHistMaker, Basic) {
  Context ctx;
  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

  ObjInfo task{ObjInfo::kRegression};
  std::unique_ptr<TreeUpdater> updater{
      TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};

  ASSERT_EQ(updater->Name(), "grow_quantile_histmaker_sycl");
}

/**
 * JSON configuration round trip: parameters set through Configure() on one
 * updater instance are written with SaveConfig(), loaded into a fresh instance
 * with LoadConfig(), and the fresh instance must serialize to an identical Json
 * document that still carries the configured values.
 */
TEST(SyclQuantileHistMaker, JsonIO) {
  Context ctx;
  ctx.UpdateAllowUnknown(Args{{"device", "sycl"}});

  // Distinctive values that the assertions below look for after the round trip.
  constexpr int kMaxDepth = 42;
  constexpr bool kSinglePrecisionHistogram = true;

  ObjInfo task{ObjInfo::kRegression};
  Json config{Object()};
  {
    std::unique_ptr<TreeUpdater> updater{
        TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};
    // Deliberately two separate Configure calls: the max_depth assertion below
    // also checks that the second call does not discard the first call's value.
    updater->Configure({{"max_depth", std::to_string(kMaxDepth)}});
    updater->Configure(
        {{"single_precision_histogram", std::to_string(kSinglePrecisionHistogram)}});
    updater->SaveConfig(&config);
  }

  {
    std::unique_ptr<TreeUpdater> updater{
        TreeUpdater::Create("grow_quantile_histmaker_sycl", &ctx, &task)};
    updater->LoadConfig(config);

    Json new_config{Object()};
    updater->SaveConfig(&new_config);

    // The reloaded updater must serialize to exactly what it was given.
    ASSERT_EQ(config, new_config);

    // Parameter values are stored as strings in the Json config. std::stoi
    // throws on a non-numeric value (reported by gtest as a failure) instead
    // of silently yielding 0 the way atoi does.
    auto const max_depth =
        std::stoi(get<String const>(new_config["train_param"]["max_depth"]));
    ASSERT_EQ(max_depth, kMaxDepth);

    // The boolean parameter is expected to come back serialized as "1".
    auto const single_precision_histogram = std::stoi(get<String const>(
        new_config["sycl_hist_train_param"]["single_precision_histogram"]));
    ASSERT_EQ(single_precision_histogram, 1);
  }
}
}  // namespace xgboost::sycl::tree