File: test_survival_util.cc

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (45 lines) | stat: -rw-r--r-- 1,958 bytes parent folder | download | duplicates (4)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
/*!
 * Copyright (c) by Contributors 2020
 */
#include <cmath>  // std::isfinite, std::log, std::pow

#include <gtest/gtest.h>

#include "../../../src/common/survival_util.h"

namespace xgboost {
namespace common {

template <typename Distribution>
inline static void RobustTestSuite(double y_lower, double y_upper, double sigma) {
  for (int i = 50; i >= -50; --i) {
    const double y_pred = std::pow(10.0, static_cast<double>(i));
    const double z = (std::log(y_lower) - std::log(y_pred)) / sigma;
    const double gradient
      = AFTLoss<Distribution>::Gradient(y_lower, y_upper, std::log(y_pred), sigma);
    const double hessian
      = AFTLoss<Distribution>::Hessian(y_lower, y_upper, std::log(y_pred), sigma);
    ASSERT_FALSE(std::isnan(gradient)) << "z = " << z << ", y \\in ["
      << y_lower << ", " << y_upper << "], y_pred = " << y_pred
      << ", dist = " << static_cast<int>(Distribution::Type());
    ASSERT_FALSE(std::isinf(gradient)) << "z = " << z << ", y \\in ["
      << y_lower << ", " << y_upper << "], y_pred = " << y_pred
      << ", dist = " << static_cast<int>(Distribution::Type());
    ASSERT_FALSE(std::isnan(hessian)) << "z = " << z << ", y \\in ["
      << y_lower << ", " << y_upper << "], y_pred = " << y_pred
      << ", dist = " << static_cast<int>(Distribution::Type());
    ASSERT_FALSE(std::isinf(hessian)) << "z = " << z << ", y \\in ["
      << y_lower << ", " << y_upper << "], y_pred = " << y_pred
      << ", dist = " << static_cast<int>(Distribution::Type());
  }
}

// Ensure that INF and NAN never show up in the gradient pair, for every
// distribution, on a label interval with y_lower < y_upper ([16, 200]) and on
// one with y_lower == y_upper ([100, 100]; presumably the uncensored case).
TEST(AFTLoss, RobustGradientPair) {
  const double sigma = 2.0;
  // Run the robustness sweep for each distribution on the same label bounds.
  const auto check_every_distribution = [sigma](double lower, double upper) {
    RobustTestSuite<NormalDistribution>(lower, upper, sigma);
    RobustTestSuite<LogisticDistribution>(lower, upper, sigma);
    RobustTestSuite<ExtremeDistribution>(lower, upper, sigma);
  };
  check_every_distribution(16.0, 200.0);
  check_every_distribution(100.0, 100.0);
}

}  // namespace common
}  // namespace xgboost