File: linalg.py

package info (click to toggle)
scikit-learn 1.7.2+dfsg-3
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 25,752 kB
  • sloc: python: 219,120; cpp: 5,790; ansic: 846; makefile: 191; javascript: 110
file content (143 lines) | stat: -rw-r--r-- 4,039 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
# pyright: reportAttributeAccessIssue=false
# pyright: reportUnknownArgumentType=false
# pyright: reportUnknownMemberType=false
# pyright: reportUnknownVariableType=false

from __future__ import annotations

import numpy as np

# intersection of `np.linalg.__all__` on numpy 1.22 and 2.2, minus `_linalg.__all__`
from numpy.linalg import (
    LinAlgError,
    cond,
    det,
    eig,
    eigvals,
    eigvalsh,
    inv,
    lstsq,
    matrix_power,
    multi_dot,
    norm,
    tensorinv,
    tensorsolve,
)

from .._internal import get_xp
from ..common import _linalg

# These functions are in both the main and linalg namespaces
from ._aliases import matmul, matrix_transpose, tensordot, vecdot  # noqa: F401
from ._typing import Array

# Result-type namedtuples are re-exported unchanged from the shared module.
EighResult = _linalg.EighResult
QRResult = _linalg.QRResult
SlogdetResult = _linalg.SlogdetResult
SVDResult = _linalg.SVDResult

# The remaining shared implementations are wrapped with ``get_xp(np)`` so they
# are bound to NumPy. The decorator is created once and applied to each one.
_wrap_for_numpy = get_xp(np)

cross = _wrap_for_numpy(_linalg.cross)
outer = _wrap_for_numpy(_linalg.outer)
eigh = _wrap_for_numpy(_linalg.eigh)
qr = _wrap_for_numpy(_linalg.qr)
slogdet = _wrap_for_numpy(_linalg.slogdet)
svd = _wrap_for_numpy(_linalg.svd)
cholesky = _wrap_for_numpy(_linalg.cholesky)
matrix_rank = _wrap_for_numpy(_linalg.matrix_rank)
pinv = _wrap_for_numpy(_linalg.pinv)
matrix_norm = _wrap_for_numpy(_linalg.matrix_norm)
svdvals = _wrap_for_numpy(_linalg.svdvals)
diagonal = _wrap_for_numpy(_linalg.diagonal)
trace = _wrap_for_numpy(_linalg.trace)

# Do not leave the temporary decorator behind as a module attribute.
del _wrap_for_numpy

# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
# of matrices. The np.linalg.solve behavior of allowing stacks of both
# matrices and vectors is ambiguous c.f.
# https://github.com/numpy/numpy/issues/15349 and
# https://github.com/data-apis/array-api/issues/285.

# To work around this, the code below is adapted from np.linalg.solve, except
# that it only calls solve1 when x2 is exactly 1-D.


# This code is here instead of in common because it is numpy specific. Also
# note that CuPy's solve() does not currently support broadcasting (see
# https://github.com/cupy/cupy/blob/main/cupy/cublas.py#L43).
def solve(x1: Array, x2: Array, /) -> Array:
    """Solve ``x1 @ x = x2`` for ``x`` using array API semantics for ``x2``.

    This is adapted from ``np.linalg.solve``. The one intentional difference
    is how ``x2`` is interpreted. It is treated as a vector only when it is
    exactly 1-dimensional; every other ``x2`` is treated as a (stack of)
    matrices. This avoids the ambiguity of accepting stacks of both matrices
    and vectors.

    Parameters
    ----------
    x1 : Array
        A square matrix or a stack of square matrices. It must be at least
        2-D, and its last two dimensions must be equal.
    x2 : Array
        Right-hand side. A 1-D ``x2`` is solved against every matrix in
        ``x1``. Otherwise it is treated as a (stack of) matrices whose shape
        must be compatible with ``x1``.

    Returns
    -------
    Array
        The solution, cast to the common result dtype of ``x1`` and ``x2``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``x1`` is not at least 2-D, is not square in its last two
        dimensions, or is singular.
    """
    try:
        # Newer NumPy keeps these private helpers in ``numpy.linalg._linalg``.
        from numpy.linalg._linalg import (
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _makearray,
            _raise_linalgerror_singular,
            isComplexType,
        )
    except ImportError:
        # Older NumPy releases expose them from ``numpy.linalg.linalg``.
        from numpy.linalg.linalg import (
            _assert_stacked_2d,
            _assert_stacked_square,
            _commonType,
            _makearray,
            _raise_linalgerror_singular,
            isComplexType,
        )
    from numpy.linalg import _umath_linalg

    x1, _ = _makearray(x1)
    _assert_stacked_2d(x1)
    _assert_stacked_square(x1)
    x2, wrap = _makearray(x2)
    t, result_t = _commonType(x1, x2)

    # This is where we differ from np.linalg.solve. The matrix-vector gufunc
    # is used only when x2 is exactly 1-D; anything else goes through the
    # matrix-matrix gufunc.
    gufunc: np.ufunc = (
        _umath_linalg.solve1 if x2.ndim == 1 else _umath_linalg.solve
    )

    # Select the complex ("D") or real ("d") double-precision loop from the
    # common computation type. The result is cast back to result_t below.
    signature = "DD->D" if isComplexType(t) else "dd->d"
    with np.errstate(
        # An "invalid" floating-point error from the gufunc is routed to
        # _raise_linalgerror_singular, which reports a singular matrix.
        call=_raise_linalgerror_singular,
        invalid="call",
        over="ignore",
        divide="ignore",
        under="ignore",
    ):
        r: Array = gufunc(x1, x2, signature=signature)

    return wrap(r.astype(result_t, copy=False))


# These functions are completely new here. If the library already provides
# them (e.g., NumPy >= 2.0), use the library version instead of our wrapper.
# Use NumPy's native ``vector_norm`` when this NumPy has one; otherwise fall
# back to the shared implementation wrapped with ``get_xp(np)``.
vector_norm = (
    np.linalg.vector_norm
    if hasattr(np.linalg, "vector_norm")
    else get_xp(np)(_linalg.vector_norm)
)


# Public API. The explicit list holds NumPy linalg functions re-exported
# as-is; it is meant to exclude anything already in ``_linalg.__all__``.
__all__ = [
    "LinAlgError",
    "cond",
    "det",
    "eig",
    "eigvals",
    "eigvalsh",
    "inv",
    "lstsq",
    "matrix_power",
    "multi_dot",
    "norm",
    "tensorinv",
    "tensorsolve",
]
__all__ += _linalg.__all__
__all__ += ["solve"]
# ``_linalg`` also provides a ``vector_norm`` implementation (used as the
# fallback above), so ``_linalg.__all__`` may already list it. Only add it
# when missing, so the name is not exported (or shown by ``dir()``) twice.
if "vector_norm" not in __all__:
    __all__ += ["vector_norm"]


def __dir__() -> list[str]:
    """Return this module's public names for ``dir()`` (module ``__dir__``).

    A new list is returned so that callers mutating the result cannot
    change the module's own ``__all__``.
    """
    return list(__all__)