File: test_hinge.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (56 lines) | stat: -rw-r--r-- 1,948 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
/**
 * Copyright 2018-2023, XGBoost Contributors
 */
#include <xgboost/objective.h>
#include <xgboost/context.h>
#include <limits>

#include "../helpers.h"
#include "test_hinge.h"
#include "../../../src/common/linalg_op.h"
namespace xgboost {

void TestHingeObj(const Context* ctx) {
  std::unique_ptr<ObjFunction> obj{ObjFunction::Create("binary:hinge", ctx)};

  float eps = std::numeric_limits<xgboost::bst_float>::min();
  std::vector<float> predt{-1.0f, -0.5f, 0.5f, 1.0f, -1.0f, -0.5f, 0.5f, 1.0f};
  std::vector<float> label{ 0.0f,  0.0f, 0.0f, 0.0f,  1.0f,  1.0f,  1.0f, 1.0f};
  std::vector<float> grad{0.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 0.0f};
  std::vector<float> hess{eps, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, eps};

  CheckObjFunction(obj, predt, label, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, grad, hess);
  CheckObjFunction(obj, predt, label, {/* Empty weight. */}, grad, hess);

  ASSERT_EQ(obj->DefaultEvalMetric(), StringView{"error"});

  MetaInfo info;
  info.num_row_ = label.size();
  info.labels.Reshape(info.num_row_, 3);
  ASSERT_EQ(obj->Targets(info), 3);
  auto h_labels = info.labels.HostView();
  for (std::size_t j = 0; j < obj->Targets(info); ++j) {
    for (std::size_t i = 0; i < info.num_row_; ++i) {
      h_labels(i, j) = label[i];
    }
  }
  linalg::Tensor<float, 2> t_predt{};
  t_predt.Reshape(info.labels.Shape());
  for (std::size_t j = 0; j < obj->Targets(info); ++j) {
    for (std::size_t i = 0; i < info.num_row_; ++i) {
      t_predt(i, j) = predt[i];
    }
  }
  linalg::Matrix<GradientPair> out_gpair;
  obj->GetGradient(*t_predt.Data(), info, 0, &out_gpair);

  for (std::size_t j = 0; j < obj->Targets(info); ++j) {
    auto gh = out_gpair.Slice(linalg::All(), j);
    ASSERT_EQ(gh.Size(), info.num_row_);
    for (std::size_t i = 0; i < gh.Size(); ++i) {
      ASSERT_EQ(gh(i).GetGrad(), grad[i]);
      ASSERT_EQ(gh(i).GetHess(), hess[i]);
    }
  }
}
}  // namespace xgboost