1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
|
from importlib import import_module
import pytest
# Libraries that ``import_`` redirects to ``array_api_compat.<name>`` when a
# wrapper is requested.
wrapped_libraries = [
    "numpy",
    "torch",
    "dask.array",
]

# Bound to the very same list object as ``wrapped_libraries`` (not a copy), so
# an in-place change made through either name is visible through the other.
all_libraries = wrapped_libraries
def import_(library, wrapper=False):
    """
    Import ``library`` for a test and return the resulting module.

    ``pytest.importorskip`` is called on ``library`` first, so the availability
    check always targets the name that was passed in, before any redirection.

    When ``wrapper`` is true, the name that is finally imported may differ:

    * a name containing ``'jax'`` is replaced by ``jax.experimental.array_api``
      if ``jax.numpy`` has no ``__array_api_version__`` attribute, and is kept
      unchanged otherwise;
    * a name listed in ``wrapped_libraries`` is prefixed with
      ``array_api_compat.``;
    * any other name is imported unchanged.
    """
    pytest.importorskip(library)

    # Without a wrapper there is nothing to redirect: import the name as given.
    if not wrapper:
        return import_module(library)

    if 'jax' in library:
        # JAX v0.4.32 implements the array API directly in jax.numpy
        # Older jax versions use jax.experimental.array_api
        if not hasattr(import_module("jax.numpy"), "__array_api_version__"):
            library = 'jax.experimental.array_api'
    elif library in wrapped_libraries:
        library = 'array_api_compat.' + library

    return import_module(library)
def xfail(request: pytest.FixtureRequest, reason: str) -> None:
    """
    Flag the currently running test as an expected failure (XFAIL).

    The xfail marker is attached to the test's node instead of calling
    ``pytest.xfail``. ``pytest.xfail`` stops the test on the spot; with the
    marker, the remainder of the test still runs and can be reported as XPASS
    if it succeeds.

    ``reason`` is passed through as the marker's ``reason``.

    xref https://github.com/pandas-dev/pandas/issues/38902
    """
    test_node = request.node
    xfail_marker = pytest.mark.xfail(reason=reason)
    test_node.add_marker(xfail_marker)
|