##################
Callback Functions
##################

This document gives a basic walkthrough of the :ref:`callback API <callback_api>` used in
the XGBoost Python package.  XGBoost 1.3 introduced a new callback interface for the Python
package, which provides the flexibility to design various extensions for training.
XGBoost also has a number of predefined callbacks for supporting early stopping,
checkpoints, etc.


Using builtin callbacks
-----------------------

By default, training methods in XGBoost have parameters like ``early_stopping_rounds`` and
``verbose``/``verbose_eval``.  When these are specified, the training procedure defines the
corresponding callbacks internally.  For example, when ``early_stopping_rounds`` is
specified, the :py:class:`EarlyStopping <xgboost.callback.EarlyStopping>` callback is invoked
inside the iteration loop.  You can also pass this callback directly to XGBoost:

.. code-block:: python

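    import numpy as np
    import xgboost as xgb

    # X_train, y_train, X_valid and y_valid are assumed to be existing arrays.
    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    # Define a custom evaluation metric used for early stopping.  It counts
    # the samples misclassified at a 0.5 threshold.
    def eval_error_metric(predt: np.ndarray, dtrain: xgb.DMatrix):
        label = dtrain.get_label()
        r = np.zeros(predt.shape)
        gt = predt > 0.5
        r[gt] = 1 - label[gt]
        le = predt <= 0.5
        r[le] = label[le]
        return 'CustomErr', np.sum(r)

    # Rounds without improvement before training stops (example value).
    early_stopping_rounds = 5

    # Specify which dataset and which metric should be used for early stopping.
    early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                            metric_name='CustomErr',
                                            data_name='Valid')

    # `custom_metric` replaces the deprecated `feval`; with a built-in
    # objective it receives transformed predictions (probabilities here).
    booster = xgb.train(
        {'objective': 'binary:logistic',
         'eval_metric': ['error', 'rmse'],
         'tree_method': 'hist'}, D_train,
        evals=[(D_train, 'Train'), (D_valid, 'Valid')],
        custom_metric=eval_error_metric,
        num_boost_round=1000,
        callbacks=[early_stop],
        verbose_eval=False)

    # One history entry is recorded per boosting round, one tree per round.
    dump = booster.get_dump(dump_format='json')
    assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)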
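
The stopping rule that ``rounds`` controls can be illustrated without XGBoost.  The
following is a simplified sketch in plain Python, not the library's implementation: it
assumes a metric where lower is better and ignores the callback's other options.  The
function name ``stopping_round`` and the score list are made up for illustration.

.. code-block:: python

    from typing import List, Optional


    def stopping_round(scores: List[float], rounds: int) -> Optional[int]:
        """Return the 0-based round at which training would stop, or None."""
        best = float("inf")  # lower is better, as for an error metric
        stale = 0            # consecutive rounds without improvement
        for i, score in enumerate(scores):
            if score < best:
                best, stale = score, 0
            else:
                stale += 1
            if stale >= rounds:
                return i
        return None


    # Hypothetical validation scores, one per boosting round.
    history = [0.30, 0.25, 0.24, 0.26, 0.25, 0.24, 0.27]
    print(stopping_round(history, rounds=3))  # prints 5

Here the best score is reached at round 2, and rounds 3 to 5 bring no strict improvement,
so the sketch stops at round 5.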


Defining your own callback
--------------------------

XGBoost provides a callback interface class, :py:class:`TrainingCallback
<xgboost.callback.TrainingCallback>`.  User-defined callbacks should inherit from this class
and override the corresponding methods.  There's a working example in
:ref:`sphx_glr_python_examples_callbacks.py`.
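
As a minimal sketch of what such a subclass might look like, the callback below overrides
``after_iteration`` and requests a stop once a metric falls below a threshold.  The class
name ``ThresholdStop``, the threshold value and the synthetic data are assumptions made
for illustration only; they are not part of XGBoost.

.. code-block:: python

    import numpy as np
    import xgboost as xgb


    class ThresholdStop(xgb.callback.TrainingCallback):
        """Hypothetical callback: stop once a metric drops below a threshold."""

        def __init__(self, data_name: str, metric_name: str, threshold: float):
            super().__init__()
            self.data_name = data_name
            self.metric_name = metric_name
            self.threshold = threshold
            self.history = []

        def after_iteration(self, model, epoch, evals_log) -> bool:
            # evals_log maps data name -> metric name -> list of scores.
            score = evals_log[self.data_name][self.metric_name][-1]
            self.history.append(score)
            # Returning True asks the training loop to stop.
            return score < self.threshold


    # Synthetic, easily separable data (illustrative only).
    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 5))
    y = (X[:, 0] > 0).astype(np.float32)
    dtrain = xgb.DMatrix(X, y)

    stopper = ThresholdStop(data_name='Train', metric_name='logloss', threshold=0.3)
    booster = xgb.train(
        {'objective': 'binary:logistic', 'eval_metric': 'logloss'},
        dtrain,
        num_boost_round=50,
        evals=[(dtrain, 'Train')],
        callbacks=[stopper],
        verbose_eval=False)

    print(len(stopper.history), stopper.history[-1])

The names passed as ``data_name`` and ``metric_name`` must match an entry in ``evals`` and
an evaluated metric, since they are used as keys into ``evals_log``.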