1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
|
import os
import warnings
from enum import Enum
from ._version import __version__, __version_tuple__ # noqa: F401
# Version of the Python array API standard advertised by this namespace
# (the ``__array_api_version__`` attribute defined by that standard).
__array_api_version__: str = "2024.12"
class _BackendType(Enum):
    """Computational back-ends that ``sparse`` can be switched to.

    Each member's value is the exact (case-sensitive) string that is
    accepted in the ``SPARSE_BACKEND`` environment variable.
    """

    Numba = "Numba"  # used when no back-end is requested
    Finch = "Finch"
    MLIR = "MLIR"
# Name of the environment variable read at import time to choose a back-end.
_ENV_VAR_NAME: str = "SPARSE_BACKEND"
class SparseFutureWarning(FutureWarning):
    """Warning category for development-only ``sparse`` features.

    Emitted, for example, when the back-end is selected through the
    ``SPARSE_BACKEND`` environment variable.
    """
# Select the back-end from the environment. Any non-empty value of the
# variable counts as an explicit request and triggers a development-feature
# warning; an unset or empty variable means the default Numba back-end.
_backend_name = os.environ.get(_ENV_VAR_NAME, "")
if _backend_name != "":
    warnings.warn(
        "Changing back-ends is a development feature, please do not rely on it in production.",
        SparseFutureWarning,
        stacklevel=1,
    )
else:
    _backend_name = _BackendType.Numba.value

# Look the member up by *value* (the string the user supplied) rather than by
# name. Validation and lookup then use the same key and cannot drift apart if
# a member's name and value ever differ. Matching is case-sensitive.
try:
    _BACKEND = _BackendType(_backend_name)
except ValueError:
    warnings.warn(f"Invalid backend identifier: {_backend_name}. Selecting Numba backend.", UserWarning, stacklevel=1)
    _BACKEND = _BackendType.Numba
del _backend_name
# Re-export the selected back-end's public API (and its ``__all__``) from the
# top-level package. Exactly one branch runs, chosen by ``_BACKEND``.
if _BACKEND is _BackendType.MLIR:
    from sparse.mlir_backend import *  # noqa: F403
    from sparse.mlir_backend import __all__
elif _BACKEND is _BackendType.Finch:
    from sparse.finch_backend import *  # noqa: F403
    from sparse.finch_backend import __all__
else:
    # Numba (the fallback branch): in addition to the star-import, explicitly
    # bind ``__array_namespace_info__`` and the listed private submodules as
    # attributes of this package. A star-import skips underscore-prefixed names
    # unless they appear in ``__all__``.
    from sparse.numba_backend import *  # noqa: F403
    from sparse.numba_backend import (  # noqa: F401
        __all__,
        __array_namespace_info__,
        _common,
        _compressed,
        _coo,
        _dok,
        _io,
        _numba_extension,
        _settings,
        _slicing,
        _sparse_array,
        _umath,
        _utils,
    )
|