File: test_with_polars.py

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (145 lines) | stat: -rw-r--r-- 4,073 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""Copyright 2024, XGBoost contributors"""

import json
import os
import tempfile
from typing import Type, Union

import numpy as np
import pytest

import xgboost as xgb

pl = pytest.importorskip("polars")


@pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
def test_polars_basic(
    DMatrixT: Union[Type[xgb.DMatrix], Type[xgb.QuantileDMatrix]]
) -> None:
    """Build a DMatrix / QuantileDMatrix from polars frames and check it.

    An integer frame is used to check the shape, the non-missing count, the
    feature metadata and the stored values. A boolean frame is then used to
    check that booleans are stored as 1/0.
    """
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5]})
    n_rows, n_cols = frame.shape
    dmat = DMatrixT(frame)

    assert dmat.num_row() == n_rows
    assert dmat.num_col() == n_cols
    # No nulls in the frame, so every cell is non-missing.
    assert dmat.num_nonmissing() == n_rows * n_cols

    # Column names and integer dtypes are carried over as feature metadata.
    assert dmat.feature_names == frame.columns
    assert dmat.feature_types == ["int", "int"]

    stored = dmat.get_data().toarray()
    expected = frame.to_numpy()
    # QuantileDMatrix: skip the min values in the cut (the first row);
    # a plain DMatrix is compared in full.
    first = 1 if isinstance(dmat, xgb.QuantileDMatrix) else 0
    np.testing.assert_allclose(stored[first:, :], expected[first:, :])

    # Boolean columns: the non-missing values come back as 1/0.
    bool_frame = pl.DataFrame(
        {"a": [True, False, False], "b": [False, False, True]}
    )
    bool_dmat = DMatrixT(bool_frame)
    np.testing.assert_allclose(
        bool_dmat.get_data().data, np.array([1, 0, 0, 0, 0, 1]), atol=1e-5
    )


def test_polars_missing() -> None:
    """Null entries of a polars frame become missing values in a DMatrix.

    Also trains a one-round booster on the frame and checks that
    ``inplace_predict`` on the raw frame matches ``predict`` on the DMatrix.
    """
    frame = pl.DataFrame({"a": [1, None, 3], "b": [3, 4, None]})
    dmat = xgb.DMatrix(frame)
    n_rows, n_cols = frame.shape
    assert dmat.num_row() == n_rows
    assert dmat.num_col() == n_cols
    # Two of the six cells are null.
    assert dmat.num_nonmissing() == 4

    # Sparse view: only the non-null cells are stored, row by row.
    csr = dmat.get_data()
    np.testing.assert_allclose(csr.data, np.array([1, 3, 4, 3]))
    np.testing.assert_allclose(csr.indptr, np.array([0, 2, 3, 4]))
    np.testing.assert_allclose(csr.indices, np.array([0, 1, 1, 0]))

    labels = pl.Series("y", np.arange(0, n_rows))
    dmat.set_info(label=labels)
    booster = xgb.train({}, dmat, num_boost_round=1)
    # Predicting on the raw frame must agree with predicting on the DMatrix.
    np.testing.assert_allclose(
        booster.inplace_predict(frame), booster.predict(dmat)
    )


def test_classififer() -> None:
    """Check that XGBClassifier treats polars inputs like numpy inputs.

    Covers three cases:

    * binary targets as a polars Series, where the saved model and the
      predictions must match a model fitted on the numpy arrays;
    * multi-label targets as a polars DataFrame;
    * multi-class targets as a polars Series.

    NOTE(review): the name misspells "classifier"; it is kept unchanged so
    the test id stays stable.
    """
    from sklearn.datasets import make_classification, make_multilabel_classification

    # Every data generator below uses the same seed so the test is reproducible.
    seed = 2024

    X, y = make_classification(random_state=seed)
    X_df = pl.DataFrame(X)
    y_ser = pl.Series(y)

    clf0 = xgb.XGBClassifier()
    clf0.fit(X_df, y_ser)

    clf1 = xgb.XGBClassifier()
    clf1.fit(X, y)

    with tempfile.TemporaryDirectory() as tmpdir:
        path0 = os.path.join(tmpdir, "clf0.json")
        clf0.save_model(path0)

        path1 = os.path.join(tmpdir, "clf1.json")
        clf1.save_model(path1)

        with open(path0, "r", encoding="utf-8") as fd:
            model0 = json.load(fd)
        with open(path1, "r", encoding="utf-8") as fd:
            model1 = json.load(fd)

    # The polars-fitted model carries feature names/types from the frame.
    # Clear them so the rest of the model can be compared with the
    # numpy-fitted one.
    model0["learner"]["feature_names"] = []
    model0["learner"]["feature_types"] = []
    assert model0 == model1

    predt0 = clf0.predict(X)
    predt1 = clf1.predict(X)

    np.testing.assert_allclose(predt0, predt1)

    assert (clf0.feature_names_in_ == X_df.columns).all()
    assert clf0.n_features_in_ == X_df.shape[1]

    # Multi-label targets passed as a polars DataFrame.
    X, y = make_multilabel_classification(128, random_state=seed)
    X_df = pl.DataFrame(X)
    y_df = pl.DataFrame(y)
    clf = xgb.XGBClassifier(n_estimators=1)
    clf.fit(X_df, y_df)
    assert clf.n_classes_ == 2

    # Multi-class targets passed as a polars Series.
    X, y = make_classification(n_classes=3, n_informative=5, random_state=seed)
    X_df = pl.DataFrame(X)
    y_ser = pl.Series(y)
    clf = xgb.XGBClassifier(n_estimators=1)
    clf.fit(X_df, y_ser)
    assert clf.n_classes_ == 3


def test_regressor() -> None:
    """Check that XGBRegressor treats multi-target polars inputs like numpy.

    A regressor fitted on a polars feature frame and a 3-column polars target
    frame must predict the same values as one fitted on the numpy arrays.
    """
    from sklearn.datasets import make_regression

    # Seeded so the generated data, and therefore the test, is reproducible.
    X, y = make_regression(n_targets=3, random_state=2024)
    X_df = pl.DataFrame(X)
    y_df = pl.DataFrame(y)
    assert y_df.shape[1] == 3

    reg0 = xgb.XGBRegressor()
    reg0.fit(X_df, y_df)

    reg1 = xgb.XGBRegressor()
    reg1.fit(X, y)

    predt0 = reg0.predict(X)
    predt1 = reg1.predict(X)

    np.testing.assert_allclose(predt0, predt1)

def test_categorical() -> None:
    """Polars categorical columns are rejected with ``NotImplementedError``.

    The frame is built with an explicit schema so that column ``b`` has the
    polars ``Categorical`` dtype. Building a DMatrix from it with
    ``enable_categorical=True`` is expected to raise.
    """
    # Uses the module-level ``pl`` obtained via ``pytest.importorskip``; a
    # local ``import polars`` would only shadow it.
    # The dict keys match the schema names, so no implicit column renaming
    # by polars is involved.
    df = pl.DataFrame(
        {"a": [1, 2, 3], "b": ["a", "b", "c"]},
        schema=[("a", pl.Int64()), ("b", pl.Categorical())],
    )
    with pytest.raises(NotImplementedError, match="Categorical feature"):
        xgb.DMatrix(df, enable_categorical=True)