File: test_array_api.py

package info (click to toggle)
scikit-learn 1.2.1%2Bdfsg-1
  • links: PTS, VCS
  • area: main
  • in suites: bookworm
  • size: 23,280 kB
  • sloc: python: 184,491; cpp: 5,783; ansic: 854; makefile: 307; sh: 45; javascript: 1
file content (189 lines) | stat: -rw-r--r-- 6,709 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
import numpy
from numpy.testing import assert_array_equal
import pytest

from sklearn.base import BaseEstimator
from sklearn.utils._array_api import get_namespace
from sklearn.utils._array_api import _NumPyApiWrapper
from sklearn.utils._array_api import _ArrayAPIWrapper
from sklearn.utils._array_api import _asarray_with_order
from sklearn.utils._array_api import _convert_to_numpy
from sklearn.utils._array_api import _estimator_with_converted_arrays
from sklearn._config import config_context

# Module-level pytest mark, applied to every test in this file: ignore any
# UserWarning whose message starts with "The numpy.array_api submodule".
# NOTE(review): presumably emitted when numpy.array_api is imported -- confirm.
pytestmark = pytest.mark.filterwarnings(
    "ignore:The numpy.array_api submodule:UserWarning"
)


def test_get_namespace_ndarray():
    """Check that NumPy ndarrays always resolve to the NumPy wrapper."""
    pytest.importorskip("numpy.array_api")

    ndarray_input = numpy.asarray([[1, 2, 3]])

    # NumPy inputs must dispatch to NumPy regardless of the value of
    # array_api_dispatch, so both settings are exercised.
    for dispatch_enabled in (True, False):
        with config_context(array_api_dispatch=dispatch_enabled):
            namespace, uses_array_api = get_namespace(ndarray_input)
            assert not uses_array_api
            assert isinstance(namespace, _NumPyApiWrapper)


def test_get_namespace_array_api():
    """Check namespace resolution and error reporting for Array API inputs."""
    array_api = pytest.importorskip("numpy.array_api")

    ndarray_input = numpy.asarray([[1, 2, 3]])
    array_api_input = array_api.asarray(ndarray_input)

    with config_context(array_api_dispatch=True):
        namespace, uses_array_api = get_namespace(array_api_input)
        assert uses_array_api
        assert isinstance(namespace, _ArrayAPIWrapper)

        # Invalid calls, in order: inputs from two different namespaces,
        # then a bare Python scalar.
        failing_calls = [
            ("Multiple namespaces", (ndarray_input, array_api_input)),
            ("Unrecognized array input", (1,)),
        ]
        for expected_message, arguments in failing_calls:
            with pytest.raises(ValueError, match=expected_message):
                get_namespace(*arguments)


class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper):
    """Test-only ``_ArrayAPIWrapper`` whose ``__name__`` is chosen by the caller."""

    def __init__(self, array_namespace, name):
        # Initialise the parent wrapper with the namespace first, then set
        # the name reported by this particular instance.
        _ArrayAPIWrapper.__init__(self, array_namespace=array_namespace)
        self.__name__ = name


def test_array_api_wrapper_astype():
    """Test _ArrayAPIWrapper dtype conversion for a namespace not named NumPy."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    renamed_namespace = _AdjustableNameAPITestWrapper(
        numpy_array_api, "wrapped_numpy.array_api"
    )
    xp = _ArrayAPIWrapper(renamed_namespace)

    data = [[1, 2, 3], [3, 4, 5]]
    X = xp.asarray(data, dtype=xp.float64)

    # Explicit cast through ``astype``.
    via_astype = xp.astype(X, xp.float32)
    assert via_astype.dtype == xp.float32

    # Re-creation through ``asarray`` with a target dtype.
    via_asarray = xp.asarray(X, dtype=xp.float32)
    assert via_asarray.dtype == xp.float32


def test_array_api_wrapper_take_for_numpy_api():
    """Test ``take`` when the wrapped namespace is named numpy.array_api.

    This is intended to exercise the wrapper's fast path for that name.
    """
    numpy_array_api = pytest.importorskip("numpy.array_api")
    # Use the same name as numpy.array_api.
    same_name_namespace = _AdjustableNameAPITestWrapper(
        numpy_array_api, "numpy.array_api"
    )
    xp = _ArrayAPIWrapper(same_name_namespace)

    X = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)
    row_index = xp.asarray([1])
    taken = xp.take(X, row_index, axis=0)

    # The result must still be an Array API array and match NumPy's take.
    assert hasattr(taken, "__array_namespace__")
    assert_array_equal(taken, numpy.take(X, [1], axis=0))


def test_array_api_wrapper_take():
    """Test the ``take`` implementation exposed by _ArrayAPIWrapper."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    renamed_namespace = _AdjustableNameAPITestWrapper(
        numpy_array_api, "wrapped_numpy.array_api"
    )
    xp = _ArrayAPIWrapper(renamed_namespace)

    def check_against_numpy(array, indices, axis):
        # The wrapper's result must remain an Array API array and hold the
        # same values as ``numpy.take`` for the same indices and axis.
        taken = xp.take(array, xp.asarray(indices), axis=axis)
        assert hasattr(taken, "__array_namespace__")
        assert_array_equal(taken, numpy.take(array, indices, axis=axis))

    # 1-D input, axis=0.
    X_1d = xp.asarray([1, 2, 3], dtype=xp.float64)
    check_against_numpy(X_1d, [1], axis=0)

    # 2-D input, axis=0 and then axis=1.
    X = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)
    check_against_numpy(X, [0], axis=0)
    check_against_numpy(X, [0, 2], axis=1)

    # Unsupported axis.
    with pytest.raises(ValueError, match=r"Only axis in \(0, 1\) is supported"):
        xp.take(X, xp.asarray([0]), axis=2)

    # Unsupported number of dimensions.
    with pytest.raises(ValueError, match=r"Only X.ndim in \(1, 2\) is supported"):
        xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0)


@pytest.mark.parametrize("is_array_api", [True, False])
def test_asarray_with_order(is_array_api):
    """Test _asarray_with_order passes along order for NumPy arrays.

    Parameters
    ----------
    is_array_api : bool
        When True the input is created with ``numpy.array_api`` (the test is
        skipped if that module cannot be imported); otherwise with ``numpy``.
    """
    if is_array_api:
        xp = pytest.importorskip("numpy.array_api")
    else:
        xp = numpy

    # The input is deliberately 2-D: a 1-D array is C- and F-contiguous at
    # the same time, so it could not show whether ``order`` was honoured.
    X = xp.asarray([[1.2, 3.4, 5.1], [3.4, 5.5, 1.2]])
    X_new = _asarray_with_order(X, order="F")

    X_new_np = numpy.asarray(X_new)
    assert X_new_np.flags["F_CONTIGUOUS"]
    # A (2, 3) array cannot be both layouts, so the C flag must be cleared.
    assert not X_new_np.flags["C_CONTIGUOUS"]


def test_asarray_with_order_ignored():
    """Test _asarray_with_order ignores order for a generic Array API namespace."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    generic_namespace = _AdjustableNameAPITestWrapper(
        numpy_array_api, "wrapped.array_api"
    )

    c_ordered = numpy.asarray([[1.2, 3.4, 5.1], [3.4, 5.5, 1.2]], order="C")
    X = generic_namespace.asarray(c_ordered)

    converted = _asarray_with_order(X, order="F", xp=generic_namespace)

    # The requested "F" order must not have been applied: the layout stays C.
    layout_flags = numpy.asarray(converted).flags
    assert layout_flags["C_CONTIGUOUS"]
    assert not layout_flags["F_CONTIGUOUS"]


def test_convert_to_numpy_error():
    """Test that _convert_to_numpy raises for an unsupported namespace name."""
    numpy_array_api = pytest.importorskip("numpy.array_api")
    unsupported_namespace = _AdjustableNameAPITestWrapper(
        numpy_array_api, "wrapped.array_api"
    )

    X = unsupported_namespace.asarray([1.2, 3.4])

    expected_message = "Supported namespaces are:"
    with pytest.raises(ValueError, match=expected_message):
        _convert_to_numpy(X, xp=unsupported_namespace)


class SimpleEstimator(BaseEstimator):
    """Minimal estimator that only records its training data when fitted."""

    def fit(self, X, y=None):
        """Store ``X`` and its number of features, then return ``self``.

        Parameters
        ----------
        X : array with at least two dimensions
            Training data; rows are samples and columns are features.
        y : ignored
            Accepted but not used by this method.

        Returns
        -------
        self : SimpleEstimator
            The fitted estimator, with ``X_`` and ``n_features_`` set.
        """
        self.X_ = X
        # Features are the columns (second axis) of X, not its rows.
        self.n_features_ = X.shape[1]
        return self


@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
def test_convert_estimator_to_ndarray(array_namespace):
    """Check that a fitted estimator's array attributes become ndarrays."""
    xp = pytest.importorskip(array_namespace)

    if array_namespace == "numpy.array_api":

        def converter(array):
            return numpy.asarray(array)

    else:  # pragma: no cover

        def converter(array):
            # NOTE(review): presumably copies the CuPy data back to host
            # memory as a NumPy array -- confirm against CuPy.
            return array._array.get()

    fitted = SimpleEstimator().fit(xp.asarray([[1.3, 4.5]]))

    converted = _estimator_with_converted_arrays(fitted, converter)
    assert isinstance(converted.X_, numpy.ndarray)


def test_convert_estimator_to_array_api():
    """Check that a fitted estimator's ndarray attributes become Array API arrays."""
    xp = pytest.importorskip("numpy.array_api")

    def to_array_api(array):
        return xp.asarray(array)

    fitted = SimpleEstimator().fit(numpy.asarray([[1.3, 4.5]]))

    converted = _estimator_with_converted_arrays(fitted, to_array_api)
    assert hasattr(converted.X_, "__array_namespace__")