File: custom_metric_obj.rst

######################################
Custom Objective and Evaluation Metric
######################################

**Contents**

.. contents::
  :backlinks: none
  :local:

********
Overview
********

XGBoost is designed to be an extensible library.  One way to extend it is by providing our
own objective function for training and a corresponding metric for performance monitoring.
This document walks through implementing a customized elementwise evaluation metric and
objective for XGBoost.  Although the introduction uses Python for demonstration, the
concepts should be readily applicable to other language bindings.

.. note::

   * The ranking task does not support customized functions.
   * Breaking change was made in XGBoost 1.6.

See also the advanced usage example for more information about limitations and
workarounds for more complex objectives: :doc:`/tutorials/advanced_custom_obj`

In the following two sections, we provide a step-by-step walkthrough of implementing the
``Squared Log Error (SLE)`` objective function:

.. math::
   \frac{1}{2}[\log(pred + 1) - \log(label + 1)]^2

and its default metric ``Root Mean Squared Log Error (RMSLE)``:

.. math::
   \sqrt{\frac{1}{N}\sum_{i=1}^{N}[\log(pred_i + 1) - \log(label_i + 1)]^2}

Although XGBoost has native support for these functions, reimplementing them ourselves
gives us the opportunity to compare our results with XGBoost's internal implementation
for learning purposes.  After finishing this tutorial, we should be able to provide our
own functions for rapid experiments.  At the end, we provide some notes on non-identity
link functions along with examples of using custom metrics and objectives with the
`scikit-learn` interface.

Computing the gradient of this objective function gives:

.. math::
   g = \frac{\partial{objective}}{\partial{pred}} = \frac{\log(pred + 1) - \log(label + 1)}{pred + 1}

and the hessian (the second derivative of the objective):

.. math::
   h = \frac{\partial^2{objective}}{\partial{pred}^2} = \frac{ - \log(pred + 1) + \log(label + 1) + 1}{(pred + 1)^2}
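
These derivatives are easy to get wrong, so a quick symbolic sanity check can help.  The
sketch below uses SymPy (an extra dependency assumed here purely for verification; the
tutorial itself does not need it):

.. code-block:: python

    import sympy as sp

    pred, label = sp.symbols("pred label", positive=True)
    objective = sp.Rational(1, 2) * (sp.log(pred + 1) - sp.log(label + 1)) ** 2

    # First and second derivatives with respect to the prediction; these
    # should match the formulas above.
    print(sp.simplify(sp.diff(objective, pred)))
    print(sp.simplify(sp.diff(objective, pred, 2)))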

*****************************
Customized Objective Function
*****************************

During model training, the objective function plays an important role: it provides
gradient information, both first- and second-order gradients, based on model predictions
and the observed data labels (or targets).  Therefore, a valid objective function accepts
two inputs, namely predictions and labels.  For implementing ``SLE``, we define:

.. code-block:: python

    import numpy as np
    import xgboost as xgb
    from typing import Tuple

    def gradient(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
        '''Compute the gradient squared log error.'''
        y = dtrain.get_label()
        return (np.log1p(predt) - np.log1p(y)) / (predt + 1)

    def hessian(predt: np.ndarray, dtrain: xgb.DMatrix) -> np.ndarray:
        '''Compute the hessian for squared log error.'''
        y = dtrain.get_label()
        return ((-np.log1p(predt) + np.log1p(y) + 1) /
                np.power(predt + 1, 2))

    def squared_log(predt: np.ndarray,
                    dtrain: xgb.DMatrix) -> Tuple[np.ndarray, np.ndarray]:
        '''Squared Log Error objective. A simplified version for RMSLE used as
        objective function.
        '''
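        # log1p is undefined for inputs <= -1, so clamp the predictions first.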
        predt[predt < -1] = -1 + 1e-6
        grad = gradient(predt, dtrain)
        hess = hessian(predt, dtrain)
        return grad, hess


In the above code snippet, ``squared_log`` is the objective function we want.  It accepts a
numpy array ``predt`` as model prediction, and the training DMatrix for obtaining required
information, including labels and weights (not used here).  This objective is then used as
a callback function for XGBoost during training by passing it as an argument to
``xgb.train``:

.. code-block:: python

   xgb.train({'tree_method': 'hist', 'seed': 1994},  # any other tree method is fine.
              dtrain=dtrain,
              num_boost_round=10,
              obj=squared_log)
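
The ``dtrain`` used above is an ordinary ``DMatrix``.  For a self-contained run it could
be built from synthetic data, for instance (a minimal sketch; the data and split here are
purely illustrative):

.. code-block:: python

    # Synthetic data for illustration only.
    rng = np.random.default_rng(1994)
    X = rng.normal(size=(100, 10))
    y = np.exp(rng.normal(size=100))  # positive labels suit the log transform
    dtrain = xgb.DMatrix(X[:80], label=y[:80])
    dtest = xgb.DMatrix(X[80:], label=y[80:])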

Notice that in our definition of the objective, whether we subtract the labels from the
prediction or the other way around is important.  If you find the training error goes up
instead of down, this might be the reason.


**************************
Customized Metric Function
**************************

After defining a customized objective, we may also need a corresponding metric to monitor
our model's performance.  As mentioned above, the default metric for ``SLE`` is
``RMSLE``.  Similarly, we define another callback-like function as the new metric:

.. code-block:: python

    def rmsle(predt: np.ndarray, dtrain: xgb.DMatrix) -> Tuple[str, float]:
        ''' Root mean squared log error metric.'''
        y = dtrain.get_label()
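        # Clamp as in the objective: log1p requires inputs greater than -1.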
        predt[predt < -1] = -1 + 1e-6
        elements = np.power(np.log1p(y) - np.log1p(predt), 2)
        return 'PyRMSLE', float(np.sqrt(np.sum(elements) / len(y)))

Since we are demonstrating in Python, the metric or objective need not be a function; any
callable object should suffice.  Similar to the objective function, our metric also
accepts ``predt`` and ``dtrain`` as inputs, but returns the name of the metric itself and
a floating point value as the result.  We then pass it to XGBoost as the
``custom_metric`` parameter:

.. code-block:: python

    xgb.train({'tree_method': 'hist', 'seed': 1994,
               'disable_default_eval_metric': 1},
              dtrain=dtrain,
              num_boost_round=10,
              obj=squared_log,
              custom_metric=rmsle,
              evals=[(dtrain, 'dtrain'), (dtest, 'dtest')],
              evals_result=results)

XGBoost will then print something like:

.. code-block:: none

    [0] dtrain-PyRMSLE:1.37153  dtest-PyRMSLE:1.31487
    [1] dtrain-PyRMSLE:1.26619  dtest-PyRMSLE:1.20899
    [2] dtrain-PyRMSLE:1.17508  dtest-PyRMSLE:1.11629
    [3] dtrain-PyRMSLE:1.09836  dtest-PyRMSLE:1.03871
    [4] dtrain-PyRMSLE:1.03557  dtest-PyRMSLE:0.977186
    [5] dtrain-PyRMSLE:0.985783 dtest-PyRMSLE:0.93057
    ...

Notice that the parameter ``disable_default_eval_metric`` is used to suppress the default metric
in XGBoost.

For fully reproducible source code and comparison plots, see
:ref:`sphx_glr_python_examples_custom_rmsle.py`.

*********************
Reverse Link Function
*********************

When using a builtin objective, the raw prediction is transformed according to the
objective function.  When a custom objective is provided, XGBoost doesn't know its link
function, so the user is responsible for making the transformation for both the objective
and the custom evaluation metric.  For objectives with an identity link like ``squared
error`` this is trivial, but for other link functions like the log link or inverse link
the difference is significant.
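
For example, with a log link the raw model output must be exponentiated to recover the
actual prediction.  Both forms can be obtained from the Python package (a sketch,
assuming a trained ``booster`` and a ``DMatrix`` named ``m``):

.. code-block:: python

    raw = booster.predict(m, output_margin=True)  # untransformed margin scores
    transformed = booster.predict(m)              # transformed by the builtin objective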

As shown above, prediction behaviour in the Python package is controlled by the
``output_margin`` parameter of the ``predict`` function.  When the ``custom_metric``
parameter is used without a custom objective, the metric function receives transformed
predictions since the objective is defined by XGBoost.  However, when a custom objective
is also provided along with that metric, both the objective and the custom metric receive
raw predictions.  The following example compares the two behaviours with a multi-class
classification model.  First we define two Python metric functions implementing the same
underlying metric: ``merror_with_transform`` is used when a custom objective is also
supplied; otherwise the simpler ``merror`` is preferred, since XGBoost can perform the
transformation itself.

.. code-block:: python

    import xgboost as xgb
    import numpy as np

    def merror_with_transform(predt: np.ndarray, dtrain: xgb.DMatrix):
        """Used when a custom objective is supplied."""
        y = dtrain.get_label()
        n_classes = predt.size // y.shape[0]
        # Like the custom objective, ``predt`` holds untransformed leaf weights
        # when a custom objective is provided.  With the ``custom_metric``
        # parameter of the train function, the custom metric receives raw input
        # only when a custom objective is also being used; otherwise it
        # receives transformed predictions.
        assert predt.shape == (dtrain.num_row(), n_classes)
        out = np.zeros(dtrain.num_row())
        for r in range(predt.shape[0]):
            i = np.argmax(predt[r])
            out[r] = i

        assert y.shape == out.shape

        errors = np.zeros(dtrain.num_row())
        errors[y != out] = 1.0
        return 'PyMError', np.sum(errors) / dtrain.num_row()
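
(The explicit loop is for clarity; ``out = np.argmax(predt, axis=1)`` computes the same
result in one vectorized call.)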

The above function is only needed when we use a custom objective and XGBoost doesn't know
how to transform the prediction.  The normal implementation of a multi-class error
function is:

.. code-block:: python

    def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
        """Used when there's no custom objective."""
        y = dtrain.get_label()
        # No need to transform: with a builtin objective like ``multi:softmax``,
        # XGBoost handles it internally and ``predt`` already contains the
        # predicted classes.
        errors = np.zeros(dtrain.num_row())
        errors[y != predt] = 1.0
        return 'PyMError', np.sum(errors) / dtrain.num_row()


Next we need the custom softprob objective:

.. code-block:: python

    def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
        """Loss function.  Computing the gradient and approximated hessian (diagonal).
        Reimplements the `multi:softprob` inside XGBoost.
        """

        # Full implementation is available in the Python demo script linked below
        ...

        return grad, hess

Lastly we can train the model using the ``obj`` and ``custom_metric`` parameters (the
constants ``kClasses`` and ``kRounds`` and the ``custom_results`` dictionary come from
the demo script linked below):

.. code-block:: python

    m = xgb.DMatrix(X, y)
    booster = xgb.train(
        {"num_class": kClasses, "disable_default_eval_metric": True},
        m,
        num_boost_round=kRounds,
        obj=softprob_obj,
        custom_metric=merror_with_transform,
        evals_result=custom_results,
        evals=[(m, "train")],
    )

Or if you don't need the custom objective and just want to supply a metric that's not
available in XGBoost:

.. code-block:: python

    booster = xgb.train(
        {
            "num_class": kClasses,
            "disable_default_eval_metric": True,
            "objective": "multi:softmax",
        },
        m,
        num_boost_round=kRounds,
        # Use a simpler metric implementation.
        custom_metric=merror,
        evals_result=custom_results,
        evals=[(m, "train")],
    )

We use ``multi:softmax`` to illustrate the difference in transformed predictions.  With
``softprob`` the output prediction array has shape ``(n_samples, n_classes)``, while for
``softmax`` it's ``(n_samples, )``.  A demo for the multi-class objective function is
also available at :ref:`sphx_glr_python_examples_custom_softmax.py`.  Also, see
:doc:`/tutorials/intercept` for some more explanation.
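
A sketch of the shape difference (illustrative only; ``booster_softprob`` stands for the
first booster above, trained with the custom objective, and ``booster_softmax`` for the
second, trained with the builtin ``multi:softmax``):

.. code-block:: python

    # Hypothetical names for the two boosters trained above.
    raw = booster_softprob.predict(m)  # shape (n_samples, n_classes): raw scores
    cls = booster_softmax.predict(m)   # shape (n_samples,): predicted classes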


**********************
Scikit-Learn Interface
**********************

The scikit-learn interface of XGBoost has some utilities to improve the integration with
standard scikit-learn functions.  For instance, since XGBoost 1.6.0 users can use cost
functions (not scoring functions) from scikit-learn out of the box:

.. code-block:: python

    from sklearn.datasets import load_diabetes
    from sklearn.metrics import mean_absolute_error
    X, y = load_diabetes(return_X_y=True)
    reg = xgb.XGBRegressor(
        tree_method="hist",
        eval_metric=mean_absolute_error,
    )
    reg.fit(X, y, eval_set=[(X, y)])

Also, for a custom objective function, users can define the objective without having to
access the ``DMatrix``:

.. code-block:: python

    from typing import Tuple

    import numpy as np
    from scipy.special import softmax  # assumption: any correct softmax works here

    def softprob_obj(labels: np.ndarray, predt: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        rows = labels.shape[0]
        classes = predt.shape[1]
        grad = np.zeros((rows, classes), dtype=float)
        hess = np.zeros((rows, classes), dtype=float)
        eps = 1e-6
        for r in range(predt.shape[0]):
            target = labels[r]
            p = softmax(predt[r, :])
            for c in range(predt.shape[1]):
                g = p[c] - 1.0 if c == target else p[c]
                h = max((2.0 * p[c] * (1.0 - p[c])).item(), eps)
                grad[r, c] = g
                hess[r, c] = h

        grad = grad.reshape((rows * classes, 1))
        hess = hess.reshape((rows * classes, 1))
        return grad, hess

    clf = xgb.XGBClassifier(tree_method="hist", objective=softprob_obj)
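
A minimal usage sketch with synthetic data (purely illustrative):

.. code-block:: python

    from sklearn.datasets import make_classification

    # Synthetic multi-class data, for illustration only.
    X, y = make_classification(
        n_samples=200, n_classes=3, n_informative=5, random_state=0
    )
    clf.fit(X, y)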