File: update_process.py

package info (click to toggle)
xgboost 3.0.4-1
  • links: PTS, VCS
  • area: main
  • in suites: sid
  • size: 13,848 kB
  • sloc: cpp: 67,603; python: 35,537; java: 4,676; ansic: 1,426; sh: 1,352; xml: 1,226; makefile: 204; javascript: 19
file content (95 lines) | stat: -rw-r--r-- 3,247 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""
Demo for using `process_type` with `prune` and `refresh`
========================================================

Modifying existing trees is not a well-established use of XGBoost, so feel free to
experiment.

"""

import numpy as np
from sklearn.datasets import fetch_california_housing

import xgboost as xgb


def main(device: str = "cuda") -> None:
    """Demonstrate updating an existing model with ``process_type="update"``.

    A booster is trained on the first half of the California housing data, then
    three updates are applied to it:

    1. ``refresh`` with ``refresh_leaf=True``: leaf values and tree statistics are
       recomputed from the other half of the data; split conditions are kept.
    2. ``refresh`` with ``refresh_leaf=False``: only tree statistics (cover, weight)
       are refreshed, so predictions stay the same while SHAP values change.
    3. ``prune`` with ``max_depth=2``: the trees are cut back to a smaller depth.

    Parameters
    ----------
    device :
        Device used to train the initial model (passed as XGBoost's ``device``
        parameter).  The default keeps the original GPU behaviour; pass ``"cpu"``
        to run the demo without a GPU.

    Raises
    ------
    AssertionError
        If one of the demonstrated properties does not hold.
    """
    n_rounds = 32

    X, y = fetch_california_housing(return_X_y=True)

    # Train a model first, on the first half of the data.
    X_train = X[: X.shape[0] // 2]
    y_train = y[: y.shape[0] // 2]
    Xy = xgb.DMatrix(X_train, y_train)
    evals_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    booster = xgb.train(
        {"tree_method": "hist", "max_depth": 6, "device": device},
        Xy,
        num_boost_round=n_rounds,
        evals=[(Xy, "Train")],
        evals_result=evals_result,
    )
    SHAP = booster.predict(Xy, pred_contribs=True)

    # Refresh the leaf value and tree statistic with the other half of the data.
    X_refresh = X[X.shape[0] // 2 :]
    y_refresh = y[y.shape[0] // 2 :]
    Xy_refresh = xgb.DMatrix(X_refresh, y_refresh)
    # The model will adapt to other half of the data by changing leaf value (no change in
    # split condition) with refresh_leaf set to True.
    leaf_refresh_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    leaf_refreshed = xgb.train(
        {"process_type": "update", "updater": "refresh", "refresh_leaf": True},
        Xy_refresh,
        num_boost_round=n_rounds,
        xgb_model=booster,
        evals=[(Xy, "Original"), (Xy_refresh, "Train")],
        evals_result=leaf_refresh_result,
    )
    # New leaf values mean new predictions (contrast with the refresh_leaf=False run).
    assert not np.allclose(
        booster.predict(Xy_refresh), leaf_refreshed.predict(Xy_refresh), rtol=1e-3
    )

    # Refresh the model without changing the leaf value, but tree statistic including
    # cover and weight are refreshed.
    refresh_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    refreshed = xgb.train(
        {"process_type": "update", "updater": "refresh", "refresh_leaf": False},
        Xy_refresh,
        num_boost_round=n_rounds,
        xgb_model=booster,
        evals=[(Xy, "Original"), (Xy_refresh, "Train")],
        evals_result=refresh_result,
    )
    # Without refreshing the leaf value, resulting trees should be the same with original
    # model except for accumulated statistic.  The rtol is for floating point error in
    # prediction.
    np.testing.assert_allclose(
        refresh_result["Original"]["rmse"], evals_result["Train"]["rmse"], rtol=1e-5
    )
    # But SHAP value is changed as cover in tree nodes are changed.
    refreshed_SHAP = refreshed.predict(Xy, pred_contribs=True)
    assert not np.allclose(SHAP, refreshed_SHAP, rtol=1e-3)

    # Prune the trees with smaller max_depth, using the same rows the model was
    # trained on.
    X_update = X_train
    y_update = y_train
    Xy_update = xgb.DMatrix(X_update, y_update)

    prune_result: xgb.callback.EvaluationMonitor.EvalsLog = {}
    pruned = xgb.train(
        {"process_type": "update", "updater": "prune", "max_depth": 2},
        Xy_update,
        num_boost_round=n_rounds,
        xgb_model=booster,
        evals=[(Xy, "Original"), (Xy_update, "Train")],
        evals_result=prune_result,
    )
    # "Original" (Xy) and "Train" (Xy_update) are built from the same rows, so their
    # logs must agree.  This is a consistency check only; it does not compare the
    # pruned model against the unpruned one.
    np.testing.assert_allclose(
        np.array(prune_result["Original"]["rmse"]),
        np.array(prune_result["Train"]["rmse"]),
        atol=1e-5,
    )
    # The pruned model is smaller.  The text dump prints each reachable node of a tree
    # on its own line, so the line count is the node count.
    n_nodes_original = sum(len(tree.splitlines()) for tree in booster.get_dump())
    n_nodes_pruned = sum(len(tree.splitlines()) for tree in pruned.get_dump())
    assert n_nodes_pruned < n_nodes_original


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()