File: test_quantile_obj.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (73 lines) | stat: -rw-r--r-- 2,339 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
/**
 * Copyright 2017-2024 by XGBoost contributors
 */
#include <xgboost/base.h>       // Args
#include <xgboost/context.h>    // Context
#include <xgboost/objective.h>  // ObjFunction
#include <xgboost/span.h>       // Span

#include <memory>               // std::unique_ptr
#include <vector>               // std::vector

#include "../helpers.h"         // CheckConfigReload,MakeCUDACtx,DeclareUnifiedTest

#include "test_quantile_obj.h"

namespace xgboost {

/**
 * @brief Exercise the "reg:quantileerror" objective created on the given context.
 *
 * First configures the objective with a list of quantiles and only checks its configuration
 * via CheckConfigReload. Then configures it with the single quantile 0.6, checks the
 * configuration again, and compares gradient/hessian against hand-computed values on a tiny
 * three-row dataset.
 *
 * @param ctx Context the objective is created on; forwarded to ObjFunction::Create.
 */
void TestQuantile(const Context* ctx) {
  {
    // Multi-quantile configuration: only the configuration reload is checked here.
    Args args{{"quantile_alpha", "[0.6, 0.8]"}};
    std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:quantileerror", ctx)};
    obj->Configure(args);
    CheckConfigReload(obj, "reg:quantileerror");
  }

  Args args{{"quantile_alpha", "0.6"}};
  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:quantileerror", ctx)};
  obj->Configure(args);
  CheckConfigReload(obj, "reg:quantileerror");

  std::vector<float> predts{1.0f, 2.0f, 3.0f};
  std::vector<float> labels{3.0f, 2.0f, 1.0f};
  std::vector<float> weights{1.0f, 1.0f, 1.0f};
  // With alpha = 0.6 and unit weights, the expected gradient is -alpha where the prediction is
  // below the label (row 0) and 1 - alpha where it is equal to (row 1) or above (row 2) it.
  std::vector<float> grad{-0.6f, 0.4f, 0.4f};
  // The expected hessian equals the sample weights.
  std::vector<float> hess = weights;
  CheckObjFunction(obj, predts, labels, weights, grad, hess);
}

/**
 * @brief Check the initial base-score (intercept) estimation of "reg:quantileerror".
 *
 * Estimates the base score for labels 0, 1, ..., 9 with quantile_alpha = [0.6, 0.8]. The first
 * estimate is unweighted; the second uses the descending weights 9, 8, ..., 0. Only a single
 * base score is produced, and the expected value is the mean of the per-quantile estimates.
 *
 * @param ctx Context the objective is created on; its device also receives the label data.
 */
void TestQuantileIntercept(const Context* ctx) {
  Args args{{"quantile_alpha", "[0.6, 0.8]"}};
  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("reg:quantileerror", ctx)};
  obj->Configure(args);

  MetaInfo info;
  info.num_row_ = 10;
  info.labels.ModifyInplace([&](HostDeviceVector<float>* data, common::Span<std::size_t> shape) {
    data->SetDevice(ctx->Device());
    data->Resize(info.num_row_);
    // Single-target labels: shape (num_row_, 1).
    shape[0] = info.num_row_;
    shape[1] = 1;

    auto& h_labels = data->HostVector();
    for (std::size_t i = 0; i < info.num_row_; ++i) {
      // Explicit conversion: the row index is unsigned, the label storage is float.
      h_labels[i] = static_cast<float>(i);
    }
  });

  linalg::Vector<float> base_scores;
  obj->InitEstimation(info, &base_scores);
  ASSERT_EQ(base_scores.Size(), 1) << "Vector is not yet supported.";
  // mean([5.6, 7.8])
  ASSERT_NEAR(base_scores(0), 6.7, kRtEps);

  // Descending weights 9, 8, ..., 0 shift the weighted estimate towards the smaller labels.
  auto& h_weights = info.weights_.HostVector();
  h_weights.reserve(info.num_row_);
  for (std::size_t i = 0; i < info.num_row_; ++i) {
    // i < num_row_, so the unsigned subtraction below cannot wrap around.
    h_weights.emplace_back(static_cast<float>(info.num_row_ - i - 1));
  }

  obj->InitEstimation(info, &base_scores);
  ASSERT_EQ(base_scores.Size(), 1) << "Vector is not yet supported.";
  // mean([3, 5])
  ASSERT_NEAR(base_scores(0), 4.0, kRtEps);
}
}  // namespace xgboost