File: test_with_arrow.py

package info (click to toggle)
xgboost 3.0.0-1
  • links: PTS, VCS
  • area: main
  • in suites: trixie
  • size: 13,796 kB
  • sloc: cpp: 67,502; python: 35,503; java: 4,676; ansic: 1,426; sh: 1,320; xml: 1,197; makefile: 204; javascript: 19
file content (109 lines) | stat: -rw-r--r-- 3,880 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os

import numpy as np
import pytest

import xgboost as xgb
from xgboost import testing as tm
from xgboost.core import DataSplitMode

# Module-wide mark: every test below is skipped when either helper reports
# its "condition" as true; the two "reason" strings are joined for the report.
pytestmark = pytest.mark.skipif(
    tm.no_arrow()["condition"] or tm.no_pandas()["condition"],
    reason=tm.no_arrow()["reason"] + " or " + tm.no_pandas()["reason"],
)

# A plain ``import`` of an optional dependency would raise ImportError while
# pytest is still collecting this module, i.e. before the skipif mark above is
# ever evaluated.  ``importorskip`` turns a missing package into a module-level
# skip instead and otherwise returns the imported module, so ``pd``, ``pa``
# and ``pc`` are bound exactly as the regular imports would bind them.
pd = pytest.importorskip("pandas")
pa = pytest.importorskip("pyarrow")
pc = pytest.importorskip("pyarrow.csv")


class TestArrowTable:
    """Tests that feed ``pyarrow.Table`` objects to XGBoost data containers."""

    def test_arrow_table(self):
        """A table built from a 2x4 pandas frame yields a 2x4 DMatrix."""
        df = pd.DataFrame(
            [[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]], columns=["a", "b", "c", "d"]
        )
        table = pa.Table.from_pandas(df)
        dm = xgb.DMatrix(table)
        assert dm.num_row() == 2
        assert dm.num_col() == 4

    def test_arrow_table_with_label(self):
        """A label attached via ``set_label`` is returned by ``get_label``."""
        df = pd.DataFrame([[1, 2.0, 3.0], [2, 3.0, 4.0]], columns=["a", "b", "c"])
        table = pa.Table.from_pandas(df)
        label = np.array([0, 1])
        dm = xgb.DMatrix(table)
        dm.set_label(label)
        assert dm.num_row() == 2
        assert dm.num_col() == 3
        np.testing.assert_array_equal(dm.get_label(), np.array([0, 1]))

    def test_arrow_table_from_np(self):
        """A table assembled from per-column NumPy arrays keeps its shape."""
        # Each of the three inner arrays becomes one Arrow column of length 4,
        # hence the expected shape of 4 rows by 3 columns below.
        coldata = np.array(
            [[1.0, 1.0, 0.0, 0.0], [2.0, 0.0, 1.0, 0.0], [3.0, 0.0, 0.0, 1.0]]
        )
        cols = [pa.array(column) for column in coldata]
        table = pa.Table.from_arrays(cols, ["a", "b", "c"])
        dm = xgb.DMatrix(table)
        assert dm.num_row() == 4
        assert dm.num_col() == 3

    @pytest.mark.parametrize("DMatrixT", [xgb.DMatrix, xgb.QuantileDMatrix])
    def test_arrow_train(self, DMatrixT):
        """Training on an Arrow table must match training on the pandas frame.

        Two boosters are trained on the same data, one through Arrow tables
        (features and label) and one through pandas objects.  Their predictions
        are compared, as is an in-place prediction made directly on the Arrow
        table, and the recorded feature names and types are checked.

        Parameters
        ----------
        DMatrixT :
            The data container class under test, supplied by the
            ``parametrize`` decorator.
        """
        # Use a seeded generator so the data is the same on every run and a
        # failing comparison can be reproduced.
        rng = np.random.default_rng(2025)
        rows = 100
        X = pd.DataFrame(
            {
                "A": rng.integers(0, 10, size=rows),
                "B": rng.standard_normal(rows),
                "C": rng.permutation([1, 0] * (rows // 2)),
            }
        )
        y = pd.Series(rng.standard_normal(rows))

        # Arrow path: both the features and the label are Arrow tables.
        table = pa.Table.from_pandas(X)
        dtrain1 = DMatrixT(table)
        dtrain1.set_label(pa.Table.from_pandas(pd.DataFrame(y)))
        bst1 = xgb.train({}, dtrain1, num_boost_round=10)
        preds1 = bst1.predict(DMatrixT(X))

        # pandas path: the reference the Arrow path is compared against.
        dtrain2 = DMatrixT(X, y)
        bst2 = xgb.train({}, dtrain2, num_boost_round=10)
        preds2 = bst2.predict(DMatrixT(X))

        np.testing.assert_allclose(preds1, preds2)

        preds3 = bst2.inplace_predict(table)
        np.testing.assert_allclose(preds1, preds3)
        assert bst2.feature_names == ["A", "B", "C"]
        assert bst2.feature_types == ["int", "float", "int"]

    def test_arrow_survival(self):
        """Arrow columns given as label bounds round-trip through the DMatrix."""
        data = os.path.join(tm.data_dir(__file__), "veterans_lung_cancer.csv")
        table = pc.read_csv(data)
        y_lower_bound = table["Survival_label_lower_bound"]
        y_upper_bound = table["Survival_label_upper_bound"]
        # The remaining columns, without the two bound columns, are the features.
        X = table.drop(["Survival_label_lower_bound", "Survival_label_upper_bound"])

        dtrain = xgb.DMatrix(
            X, label_lower_bound=y_lower_bound, label_upper_bound=y_upper_bound
        )
        y_np_up = dtrain.get_float_info("label_upper_bound")
        y_np_low = dtrain.get_float_info("label_lower_bound")
        np.testing.assert_equal(y_np_up, y_upper_bound.to_pandas().values)
        np.testing.assert_equal(y_np_low, y_lower_bound.to_pandas().values)


@pytest.mark.skipif(tm.is_windows(), reason="Rabit does not run on windows")
class TestArrowTableColumnSplit:
    """Arrow-table DMatrix construction with column-wise data splitting."""

    def test_arrow_table(self):
        """Run the column-split shape check through ``tm.run_with_rabit``."""

        def check_column_split_shape():
            # Same 2x4 frame in every invocation of this callback.
            values = [[0, 1, 2.0, 3.0], [1, 2, 3.0, 4.0]]
            names = ["a", "b", "c", "d"]
            arrow_table = pa.Table.from_pandas(pd.DataFrame(values, columns=names))

            dmat = xgb.DMatrix(arrow_table, data_split_mode=DataSplitMode.COL)

            assert dmat.num_row() == 2
            # With a column split the reported column count is the local
            # four columns multiplied by the collective's world size.
            assert dmat.num_col() == 4 * xgb.collective.get_world_size()

        tm.run_with_rabit(world_size=3, test_fn=check_column_split_shape)