"""Tests for the array API helpers in sklearn.utils._array_api."""
import numpy
from numpy.testing import assert_array_equal
import pytest
from sklearn.base import BaseEstimator
from sklearn.utils._array_api import get_namespace
from sklearn.utils._array_api import _NumPyApiWrapper
from sklearn.utils._array_api import _ArrayAPIWrapper
from sklearn.utils._array_api import _asarray_with_order
from sklearn.utils._array_api import _convert_to_numpy
from sklearn.utils._array_api import _estimator_with_converted_arrays
from sklearn._config import config_context
# Module-wide pytest mark: ignore the UserWarning whose message starts with
# "The numpy.array_api submodule" in every test of this file (presumably the
# warning emitted when numpy.array_api is imported -- confirm against NumPy).
pytestmark = pytest.mark.filterwarnings(
    "ignore:The numpy.array_api submodule:UserWarning"
)
def test_get_namespace_ndarray():
    """Check that get_namespace resolves NumPy ndarrays to the NumPy wrapper."""
    pytest.importorskip("numpy.array_api")

    ndarray = numpy.asarray([[1, 2, 3]])

    # NumPy inputs are dispatched to the NumPy wrapper regardless of the
    # value of array_api_dispatch.
    for dispatch_enabled in (True, False):
        with config_context(array_api_dispatch=dispatch_enabled):
            namespace, uses_array_api = get_namespace(ndarray)
            assert not uses_array_api
            assert isinstance(namespace, _NumPyApiWrapper)
def test_get_namespace_array_api():
    """Check get_namespace on array API inputs, including its error cases."""
    array_api = pytest.importorskip("numpy.array_api")

    ndarray = numpy.asarray([[1, 2, 3]])
    api_array = array_api.asarray(ndarray)

    with config_context(array_api_dispatch=True):
        namespace, uses_array_api = get_namespace(api_array)
        assert uses_array_api
        assert isinstance(namespace, _ArrayAPIWrapper)

        # Mixing an ndarray with an array API array must be rejected.
        with pytest.raises(ValueError, match="Multiple namespaces"):
            get_namespace(ndarray, api_array)

        # A plain Python scalar is not an accepted array input.
        with pytest.raises(ValueError, match="Unrecognized array input"):
            get_namespace(1)
class _AdjustableNameAPITestWrapper(_ArrayAPIWrapper):
    """API wrapper that has an adjustable name. Used for testing.

    Parameters
    ----------
    array_namespace : module
        Namespace forwarded unchanged to ``_ArrayAPIWrapper.__init__``.
    name : str
        Value stored as the instance's ``__name__`` attribute, so tests can
        make the wrapper report an arbitrary namespace name.
    """

    def __init__(self, array_namespace, name):
        super().__init__(array_namespace=array_namespace)
        # Set on the instance after the parent initializer has run.
        self.__name__ = name
def test_array_api_wrapper_astype():
    """Test dtype conversion through _ArrayAPIWrapper for a non-NumPy name."""
    array_api = pytest.importorskip("numpy.array_api")
    renamed_namespace = _AdjustableNameAPITestWrapper(
        array_api, "wrapped_numpy.array_api"
    )
    xp = _ArrayAPIWrapper(renamed_namespace)

    X = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)

    # Explicit conversion with astype.
    via_astype = xp.astype(X, xp.float32)
    assert via_astype.dtype == xp.float32

    # Conversion requested through asarray's dtype argument.
    via_asarray = xp.asarray(X, dtype=xp.float32)
    assert via_asarray.dtype == xp.float32
def test_array_api_wrapper_take_for_numpy_api():
    """Test take when the wrapped namespace is named like numpy.array_api."""
    array_api = pytest.importorskip("numpy.array_api")

    # Use the same name as numpy.array_api.
    same_name_namespace = _AdjustableNameAPITestWrapper(array_api, "numpy.array_api")
    xp = _ArrayAPIWrapper(same_name_namespace)

    matrix = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)
    taken = xp.take(matrix, xp.asarray([1]), axis=0)

    # The result is still an array API object and matches numpy.take.
    assert hasattr(taken, "__array_namespace__")
    assert_array_equal(taken, numpy.take(matrix, [1], axis=0))
def test_array_api_wrapper_take():
    """Compare _ArrayAPIWrapper.take with numpy.take and check its errors."""
    array_api = pytest.importorskip("numpy.array_api")
    xp = _ArrayAPIWrapper(
        _AdjustableNameAPITestWrapper(array_api, "wrapped_numpy.array_api")
    )

    def check_take(array, indices, axis):
        # The result must remain an array API object and agree with what
        # numpy.take returns for the same indices and axis.
        taken = xp.take(array, xp.asarray(indices), axis=axis)
        assert hasattr(taken, "__array_namespace__")
        assert_array_equal(taken, numpy.take(array, indices, axis=axis))

    # axis=0 on a 1-d input, then on a 2-d input.
    vector = xp.asarray([1, 2, 3], dtype=xp.float64)
    check_take(vector, [1], axis=0)

    matrix = xp.asarray([[1, 2, 3], [3, 4, 5]], dtype=xp.float64)
    check_take(matrix, [0], axis=0)

    # axis=1 on the 2-d input.
    check_take(matrix, [0, 2], axis=1)

    # An axis other than 0 or 1 is rejected.
    with pytest.raises(ValueError, match=r"Only axis in \(0, 1\) is supported"):
        xp.take(matrix, xp.asarray([0]), axis=2)

    # Inputs with more than two dimensions are rejected.
    with pytest.raises(ValueError, match=r"Only X.ndim in \(1, 2\) is supported"):
        xp.take(xp.asarray([[[0]]]), xp.asarray([0]), axis=0)
@pytest.mark.parametrize("is_array_api", [True, False])
def test_asarray_with_order(is_array_api):
    """Check that _asarray_with_order honours order="F" for both namespaces."""
    xp = pytest.importorskip("numpy.array_api") if is_array_api else numpy

    values = xp.asarray([1.2, 3.4, 5.1])
    reordered = _asarray_with_order(values, order="F")

    # Inspect the memory layout through a NumPy view of the result.
    assert numpy.asarray(reordered).flags["F_CONTIGUOUS"]
def test_asarray_with_order_ignored():
    """Check that order is not applied for a generically named namespace."""
    array_api = pytest.importorskip("numpy.array_api")
    generic_xp = _AdjustableNameAPITestWrapper(array_api, "wrapped.array_api")

    c_ordered = numpy.asarray([[1.2, 3.4, 5.1], [3.4, 5.5, 1.2]], order="C")
    api_array = generic_xp.asarray(c_ordered)

    result = _asarray_with_order(api_array, order="F", xp=generic_xp)
    layout = numpy.asarray(result).flags

    # The C layout of the input is kept even though order="F" was requested.
    assert layout["C_CONTIGUOUS"]
    assert not layout["F_CONTIGUOUS"]
def test_convert_to_numpy_error():
    """Check that _convert_to_numpy raises for an unsupported namespace name."""
    array_api = pytest.importorskip("numpy.array_api")
    unsupported_xp = _AdjustableNameAPITestWrapper(array_api, "wrapped.array_api")
    api_array = unsupported_xp.asarray([1.2, 3.4])

    with pytest.raises(ValueError, match="Supported namespaces are:"):
        _convert_to_numpy(api_array, xp=unsupported_xp)
class SimpleEstimator(BaseEstimator):
    """Minimal estimator whose ``fit`` only stores attributes derived from X."""

    def fit(self, X, y=None):
        """Store X and its number of features, then return the estimator.

        Parameters
        ----------
        X : array with at least two dimensions
            Stored unchanged as ``X_``. Assumed to follow the
            (n_samples, n_features) layout.
        y : object, default=None
            Not used.

        Returns
        -------
        self : SimpleEstimator
            The fitted estimator.
        """
        self.X_ = X
        # Features live on the second axis; the first axis counts samples.
        self.n_features_ = X.shape[1]
        return self
@pytest.mark.parametrize("array_namespace", ["numpy.array_api", "cupy.array_api"])
def test_convert_estimator_to_ndarray(array_namespace):
    """Check that fitted array attributes can be converted to ndarrays."""
    xp = pytest.importorskip(array_namespace)

    # Pick the conversion routine matching the namespace under test.
    if array_namespace == "numpy.array_api":

        def to_ndarray(array):
            return numpy.asarray(array)

    else:  # pragma: no cover

        def to_ndarray(array):
            return array._array.get()

    fitted = SimpleEstimator().fit(xp.asarray([[1.3, 4.5]]))
    converted = _estimator_with_converted_arrays(fitted, to_ndarray)

    assert isinstance(converted.X_, numpy.ndarray)
def test_convert_estimator_to_array_api():
    """Check that fitted ndarray attributes can be converted to array API arrays."""
    xp = pytest.importorskip("numpy.array_api")

    def to_array_api(array):
        return xp.asarray(array)

    fitted = SimpleEstimator().fit(numpy.asarray([[1.3, 4.5]]))
    converted = _estimator_with_converted_arrays(fitted, to_array_api)

    assert hasattr(converted.X_, "__array_namespace__")
|