File: _helpers.py

package info (click to toggle)
python-array-api-compat 1.11.2-1
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 708 kB
  • sloc: python: 3,954; sh: 16; makefile: 15
file content (31 lines) | stat: -rw-r--r-- 1,066 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from importlib import import_module

import pytest

# Libraries for which ``import_(..., wrapper=True)`` returns the
# ``array_api_compat.<name>`` wrapper namespace instead of the raw library.
wrapped_libraries = ["numpy", "torch", "dask.array"]
# All libraries known to these helpers. Its contents currently equal
# ``wrapped_libraries``, but it is built as a separate list (not an alias) so
# that extending it in place cannot also mutate ``wrapped_libraries`` and
# thereby change which names ``import_`` rewrites.
all_libraries = list(wrapped_libraries)

def import_(library, wrapper=False):
    """Import *library* for a test, skipping the test if it is unavailable.

    ``pytest.importorskip`` is applied to the requested name *before* any
    rewriting, so the skip decision is based on the base library.

    When *wrapper* is true, the module actually imported is chosen as follows:

    * any name containing ``'jax'``: the requested name unchanged if
      ``jax.numpy`` exposes ``__array_api_version__`` (JAX >= 0.4.32 ships
      the array API in ``jax.numpy``), otherwise
      ``'jax.experimental.array_api'`` for older JAX releases;
    * a name listed in ``wrapped_libraries``: ``'array_api_compat.' + name``;
    * anything else: the requested name unchanged.

    Returns the imported module object.
    """
    pytest.importorskip(library)
    target = library
    if wrapper and 'jax' in library:
        # Newer JAX implements the standard directly; fall back to the
        # experimental namespace only when that marker attribute is missing.
        if not hasattr(import_module("jax.numpy"), "__array_api_version__"):
            target = 'jax.experimental.array_api'
    elif wrapper and library in wrapped_libraries:
        target = 'array_api_compat.' + library
    return import_module(target)


def xfail(request: pytest.FixtureRequest, reason: str) -> None:
    """
    Mark the currently running test as an expected failure.

    In contrast to calling ``pytest.xfail`` (which stops the test on the
    spot), this attaches an ``xfail`` marker to the test node and lets the
    test keep running, so an unexpected success is reported as XPASS.
    xref https://github.com/pandas-dev/pandas/issues/38902

    :param request: the ``request`` fixture of the running test.
    :param reason: explanation recorded on the ``xfail`` marker.
    """
    marker = pytest.mark.xfail(reason=reason)
    request.node.add_marker(marker)