1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
|
# pyright: reportAttributeAccessIssue=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false
from __future__ import annotations
import numpy as np
# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
from numpy.linalg import (
LinAlgError,
cond,
det,
eig,
eigvals,
eigvalsh,
inv,
lstsq,
matrix_power,
multi_dot,
norm,
tensorinv,
tensorsolve,
)
from .._internal import get_xp
from ..common import _linalg
# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot # noqa: F401
from ._typing import Array
# Build the NumPy-bound decorator once and reuse it for every wrapper below.
# It is deleted at the end of this section so the module namespace is the same
# as if ``get_xp(np)`` had been spelled out on each line.
_bind_numpy = get_xp(np)

# Result containers are re-exported unchanged from the common implementation.
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult

# Vector products.
cross = _bind_numpy(_linalg.cross)
outer = _bind_numpy(_linalg.outer)

# Decompositions whose results use the containers above.
eigh = _bind_numpy(_linalg.eigh)
qr = _bind_numpy(_linalg.qr)
slogdet = _bind_numpy(_linalg.slogdet)
svd = _bind_numpy(_linalg.svd)

# Remaining functions taken from the common implementation.
cholesky = _bind_numpy(_linalg.cholesky)
matrix_rank = _bind_numpy(_linalg.matrix_rank)
pinv = _bind_numpy(_linalg.pinv)
matrix_norm = _bind_numpy(_linalg.matrix_norm)
svdvals = _bind_numpy(_linalg.svdvals)
diagonal = _bind_numpy(_linalg.diagonal)
trace = _bind_numpy(_linalg.trace)

del _bind_numpy
# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous c.f.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.
# To work around this, the code below is copied from np.linalg.solve, except
# that it only calls solve1 in the exactly 1-D case.
# This code is here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: Array, x2: Array, /) -> Array:
    """Solve the linear system ``x1 @ x = x2`` for ``x``.

    ``x2`` is treated as a single vector only when it is exactly
    1-dimensional; in every other case it is treated as a (stack of)
    matrices.

    Parameters
    ----------
    x1
        Coefficient matrix or stack of matrices; it is checked with NumPy's
        ``_assert_stacked_2d`` and ``_assert_stacked_square`` helpers.
    x2
        Right-hand side: a 1-D vector or a (stack of) matrices.

    Returns
    -------
    Array
        The solution, cast to the result type reported by NumPy's
        ``_commonType`` and passed through the wrapper obtained from ``x2``.

    Raises
    ------
    numpy.linalg.LinAlgError
        Raised by NumPy's helpers when ``x1`` is not a stack of square
        matrices or when the solver flags an invalid (singular) result.
    """
    # The private helpers live in ``numpy.linalg._linalg`` on newer NumPy
    # releases and in ``numpy.linalg.linalg`` on older ones.
    try:
        from numpy.linalg._linalg import (
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _makearray,
            _raise_linalgerror_singular,
            isComplexType,
        )
    except ImportError:
        from numpy.linalg.linalg import (
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _makearray,
            _raise_linalgerror_singular,
            isComplexType,
        )
    from numpy.linalg import _umath_linalg

    coeffs, _ = _makearray(x1)
    _assert_stacked_2d(coeffs)
    _assert_stacked_square(coeffs)
    rhs, wrap = _makearray(x2)
    compute_t, result_t = _commonType(coeffs, rhs)

    # This is where the function deviates from np.linalg.solve: the vector
    # kernel (solve1) is used only for an exactly 1-D right-hand side.
    kernel: np.ufunc = (
        _umath_linalg.solve1 if rhs.ndim == 1 else _umath_linalg.solve
    )

    # Select the complex or the real double-precision loop of the gufunc.
    if isComplexType(compute_t):
        signature = "DD->D"
    else:
        signature = "dd->d"

    # "invalid" floating-point events are routed to NumPy's singular-matrix
    # error callback; every other floating-point event is ignored.
    fp_policy = {
        "call": _raise_linalgerror_singular,
        "invalid": "call",
        "over": "ignore",
        "divide": "ignore",
        "under": "ignore",
    }
    with np.errstate(**fp_policy):
        solution: Array = kernel(coeffs, rhs, signature=signature)

    return wrap(solution.astype(result_t, copy=False))
# vector_norm is new with the array API. When the installed NumPy already
# ships it (i.e., numpy 2.0), expose the library version; otherwise fall back
# to our wrapper around the common implementation.
try:
    vector_norm = np.linalg.vector_norm
except AttributeError:
    vector_norm = get_xp(np)(_linalg.vector_norm)
# Public API, assembled in three steps: the names imported directly from
# numpy.linalg, then everything exported by the common implementation, then
# the names defined in this module.
__all__ = [
    "LinAlgError",
    "cond",
    "det",
    "eig",
    "eigvals",
    "eigvalsh",
    "inv",
    "lstsq",
    "matrix_power",
    "multi_dot",
    "norm",
    "tensorinv",
    "tensorsolve",
]
__all__.extend(_linalg.__all__)
__all__.extend(["solve", "vector_norm"])
def __dir__() -> list[str]:
    """Return the public names of this module, as listed in ``__all__``.

    A new list is returned on every call so that a caller mutating the
    result (sorting, appending, ...) cannot alter the module's ``__all__``.
    """
    return list(__all__)
|