1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59
|
"""
Internal helpers
"""
from collections.abc import Callable
from functools import wraps
from inspect import signature
from types import ModuleType
from typing import TypeVar
_T = TypeVar("_T")


def get_xp(xp: ModuleType) -> Callable[[Callable[..., _T]], Callable[..., _T]]:
    """
    Build a decorator that binds ``xp`` to one fixed array namespace.

    Every call to the decorated function receives ``xp=<xp>`` as an extra
    keyword argument. The decorated function's public signature no longer
    lists ``xp``. Example::

        import numpy as np

        @get_xp(np)
        def func(x, /, xp, kwarg=None):
            return xp.func(x, kwarg=kwarg)

    ``xp`` is always passed by keyword. It must therefore be accepted as a
    keyword argument that comes after all non-keyword arguments. If the
    decorated function has no docstring, a generic "compatibility wrapper"
    docstring is attached to the result.
    """

    def decorator(func: Callable[..., _T], /) -> Callable[..., _T]:
        @wraps(func)
        def bound(*args: object, **kwargs: object) -> object:
            # Supply the namespace by keyword so callers never pass it.
            return func(*args, xp=xp, **kwargs)

        original_sig = signature(func)
        # Hide ``xp`` from introspection, since it is filled in automatically.
        visible_params = [
            param
            for param_name, param in original_sig.parameters.items()
            if param_name != "xp"
        ]

        # ``wraps`` already copied func.__doc__; only fill in a missing one.
        if bound.__doc__ is None:
            bound.__doc__ = (
                f"Array API compatibility wrapper for {func.__name__}.\n"
                "See the corresponding documentation in NumPy/CuPy and/or the array API\n"
                "specification for more details.\n"
            )
        bound.__signature__ = original_sig.replace(parameters=visible_params)  # pyright: ignore[reportAttributeAccessIssue]
        return bound  # pyright: ignore[reportReturnType]

    return decorator
__all__ = ["get_xp"]


def __dir__() -> list[str]:
    """Return this module's public names, as listed in ``__all__``.

    Python calls a module-level ``__dir__`` (PEP 562) when ``dir(module)``
    is evaluated, so only the exported names are advertised.
    """
    exported = __all__
    return exported
|